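"""LLM completion helpers.

Thin wrappers around several chat-completion backends (OpenAI, Groq,
Mistral) with config loading, message validation, and simple retry logic.
"""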

import asyncio
import os

import openai
import yaml
from termcolor import colored

# Load config (API key, API base, model) from config/config.yaml,
# falling back to environment variables.
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
yaml_data = {}
try:
    with open(yaml_file, "r", encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file) or {}
except Exception:
    yaml_data = {}

OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or yaml_data.get(
    "OPENAI_API_BASE", "https://api.openai.com"
)
openai.api_base = OPENAI_API_BASE
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or yaml_data.get("OPENAI_API_KEY", "")
openai.api_key = OPENAI_API_KEY
MODEL: str = os.getenv("OPENAI_API_MODEL") or yaml_data.get(
    "OPENAI_API_MODEL", "gpt-4-turbo-preview"
)

# FAST_DESIGN_MODE arrives from the environment as a string, so coerce it to bool.
_fast_design_env = os.getenv("FAST_DESIGN_MODE")
if _fast_design_env is None:
    FAST_DESIGN_MODE: bool = bool(yaml_data.get("FAST_DESIGN_MODE", False))
else:
    FAST_DESIGN_MODE = _fast_design_env.lower() in ("true", "1", "yes")

GROQ_API_KEY = os.getenv("GROQ_API_KEY") or yaml_data.get("GROQ_API_KEY", "")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") or yaml_data.get("MISTRAL_API_KEY", "")
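
# Illustrative config/config.yaml (values are placeholders; the keys mirror
# the lookups above):
#
#   OPENAI_API_BASE: "https://api.openai.com"
#   OPENAI_API_KEY: "sk-..."
#   OPENAI_API_MODEL: "gpt-4-turbo-preview"
#   FAST_DESIGN_MODE: false
#   GROQ_API_KEY: ""
#   MISTRAL_API_KEY: ""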


def LLM_Completion(
    messages: list[dict], stream: bool = True, useGroq: bool = True
) -> str:
    """Run one chat completion, dispatching to Groq or OpenAI as configured."""
    # Message validation: every entry must carry a non-blank role and content.
    if not messages:
        raise ValueError("Messages list is empty")
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            raise ValueError(f"Message at index {i} is not a dictionary")
        if not msg.get("role") or str(msg.get("role")).strip() == "":
            raise ValueError(f"Message at index {i} has empty 'role'")
        if not msg.get("content") or str(msg.get("content")).strip() == "":
            raise ValueError(f"Message at index {i} has empty 'content'")
        # Extra check: warn when content is suspiciously short, which usually
        # points at a formatting problem upstream.
        content = str(msg.get("content")).strip()
        if len(content) < 10:
            print(
                f"[WARNING] Message at index {i} has very short content: '{content}'",
                flush=True,
            )

    # Backend selection: Groq is used only when a key is configured, the caller
    # asked for it, and fast design mode is on; otherwise fall back to OpenAI,
    # forcing the JSON-mode GPT-4 path when Groq was available but not wanted.
    force_gpt4 = False
    if not GROQ_API_KEY:
        useGroq = False
    elif not useGroq or not FAST_DESIGN_MODE:
        force_gpt4 = True
        useGroq = False

    if stream:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        if force_gpt4:
            return loop.run_until_complete(_achat_completion_json(messages=messages))
        if useGroq:
            return loop.run_until_complete(
                _achat_completion_stream_groq(messages=messages)
            )
        return loop.run_until_complete(_achat_completion_stream(messages=messages))
    return _chat_completion(messages=messages)


async def _achat_completion_stream_groq(messages: list[dict]) -> str:
    from groq import AsyncGroq

    client = AsyncGroq(api_key=GROQ_API_KEY)
    max_attempts = 5
    for attempt in range(max_attempts):
        print("Attempting to use Groq (Fast Design Mode):")
        try:
            response = await client.chat.completions.create(
                messages=messages,
                # model="gemma-7b-it",
                model="mixtral-8x7b-32768",
                # model="llama2-70b-4096",
                temperature=0.3,
                response_format={"type": "json_object"},
                stream=False,
            )
            break
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise RuntimeError("Groq completion failed after all retry attempts")
    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    client = MistralClient(api_key=MISTRAL_API_KEY)
    max_attempts = 5
    for attempt in range(max_attempts):
        try:
            # Mistral expects the final message to come from the user.
            messages[-1]["role"] = "user"
            response = client.chat(
                messages=[
                    ChatMessage(role=message["role"], content=message["content"])
                    for message in messages
                ],
                # model="mistral-small-latest",
                model="open-mixtral-8x7b",
                # response_format={"type": "json_object"},
            )
            break  # success: leave the retry loop
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise RuntimeError("Mistral completion failed after all retry attempts")
    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    kwargs = {
        "messages": messages,
        "max_tokens": 4096,
        "n": 1,
        "stop": None,
        "temperature": 0.3,
        "timeout": 3,
        "model": "gpt-3.5-turbo-16k",
        "stream": True,
    }
    response = await openai.ChatCompletion.acreate(**kwargs)
    # Collect the streamed chunks, printing content deltas as they arrive.
    collected_messages = []
    async for chunk in response:
        choices = chunk["choices"]
        if len(choices) > 0:
            chunk_message = choices[0].get("delta", {})  # extract the message delta
            collected_messages.append(chunk_message)
            if "content" in chunk_message:
                print(colored(chunk_message["content"], "blue", "on_white"), end="")
    print()
    full_reply_content = "".join(m.get("content", "") for m in collected_messages)
    return full_reply_content


async def _achat_completion_json(messages: list[dict]) -> str:
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    max_attempts = 5
    for attempt in range(max_attempts):
        try:
            response = await openai.ChatCompletion.acreate(
                messages=messages,
                max_tokens=4096,
                n=1,
                stop=None,
                temperature=0.3,
                timeout=3,
                model=MODEL,
                response_format={"type": "json_object"},
            )
            break
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise RuntimeError("OpenAI JSON completion failed after all retry attempts")
    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream(messages: list[dict]) -> str:
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    response = await openai.ChatCompletion.acreate(
        **_cons_kwargs(messages), stream=True
    )
    # Collect the streamed chunks, printing content deltas as they arrive.
    collected_messages = []
    async for chunk in response:
        choices = chunk["choices"]
        if len(choices) > 0:
            chunk_message = choices[0].get("delta", {})  # extract the message delta
            collected_messages.append(chunk_message)
            if "content" in chunk_message:
                print(colored(chunk_message["content"], "blue", "on_white"), end="")
    print()
    full_reply_content = "".join(m.get("content", "") for m in collected_messages)
    return full_reply_content


def _chat_completion(messages: list[dict]) -> str:
    rsp = openai.ChatCompletion.create(**_cons_kwargs(messages))
    return rsp["choices"][0]["message"]["content"]


def _cons_kwargs(messages: list[dict]) -> dict:
    kwargs = {
        "messages": messages,
        "max_tokens": 4096,
        "temperature": 0.5,
    }
    # Check for and repair None values inside the messages.
    for i, msg in enumerate(messages):
        # Ensure msg is a dictionary.
        if not isinstance(msg, dict):
            print(f"[ERROR] Message {i} is not a dictionary: {msg}", flush=True)
            messages[i] = {
                "role": "user",
                "content": str(msg) if msg is not None else "",
            }
            continue
        # Ensure role and content exist and are not None.
        if "role" not in msg or msg["role"] is None:
            print(f"[ERROR] Message {i} missing role, setting to 'user'", flush=True)
            msg["role"] = "user"
        else:
            msg["role"] = str(msg["role"]).strip()
        if "content" not in msg or msg["content"] is None:
            print(
                f"[ERROR] Message {i} missing content, setting to empty string",
                flush=True,
            )
            msg["content"] = ""
        else:
            msg["content"] = str(msg["content"]).strip()
    # Adjust parameters for the different API providers.
    if "deepseek" in MODEL.lower():
        # Special handling for the DeepSeek API: drop parameters it may not support.
        print("[DEBUG] DeepSeek API detected, adjusting parameters", flush=True)
        kwargs.pop("n", None)  # DeepSeek may not support the n parameter
        kwargs.pop("timeout", None)
        kwargs.pop("stop", None)  # DeepSeek may not support the stop parameter
    else:
        # OpenAI-compatible APIs.
        kwargs["n"] = 1
        kwargs["stop"] = None
        kwargs["timeout"] = 3
    kwargs["model"] = MODEL
    # Keep only messages with a non-empty role and content.
    kwargs["messages"] = [msg for msg in messages if msg["role"] and msg["content"]]
    print(f"[DEBUG] Final kwargs for API call: {kwargs.keys()}", flush=True)
    return kwargs
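

# Minimal usage sketch (illustrative; the message content and the choice of a
# non-streaming call are assumptions, not part of the original module).
if __name__ == "__main__":
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Reply with a one-line greeting."},
    ]
    # stream=False takes the blocking _chat_completion path; useGroq=False
    # keeps the demo on the configured OpenAI-compatible backend.
    print(LLM_Completion(demo_messages, stream=False, useGroq=False))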