refactor(LLMAPI): 重构LLM接口以支持新版本OpenAI SDK
- 升级openai依赖至2.x版本并替换旧版SDK调用方式 - 引入OpenAI和AsyncOpenAI客户端实例替代全局配置 - 更新所有聊天完成请求方法以适配新版API格式 - 为异步流式响应处理添加异常捕获和错误提示 - 统一超时时间和最大token数等默认参数设置 - 修复部分变量命名冲突和潜在的空值引用问题 - 添加打印彩色日志的辅助函数避免循环导入问题
This commit is contained in:
parent
ab8c9e294d
commit
6392301833
@ -1,24 +1,30 @@
|
||||
import asyncio
|
||||
import openai
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
import os
|
||||
|
||||
# Helper function to avoid circular import
|
||||
def print_colored(text, text_color="green", background="on_white"):
|
||||
print(colored(text, text_color, background))
|
||||
|
||||
# load config (apikey, apibase, model)
|
||||
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
|
||||
try:
|
||||
with open(yaml_file, "r", encoding="utf-8") as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
except Exception:
|
||||
yaml_file = {}
|
||||
yaml_data = {}
|
||||
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or yaml_data.get(
|
||||
"OPENAI_API_BASE", "https://api.openai.com"
|
||||
)
|
||||
openai.api_base = OPENAI_API_BASE
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or yaml_data.get(
|
||||
"OPENAI_API_KEY", ""
|
||||
)
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
|
||||
# Initialize OpenAI clients
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
|
||||
async_client = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
|
||||
MODEL: str = os.getenv("OPENAI_API_MODEL") or yaml_data.get(
|
||||
"OPENAI_API_MODEL", "gpt-4-turbo-preview"
|
||||
)
|
||||
@ -71,14 +77,14 @@ def LLM_Completion(
|
||||
|
||||
async def _achat_completion_stream_groq(messages: list[dict]) -> str:
|
||||
from groq import AsyncGroq
|
||||
client = AsyncGroq(api_key=GROQ_API_KEY)
|
||||
groq_client = AsyncGroq(api_key=GROQ_API_KEY)
|
||||
|
||||
max_attempts = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
print("Attempt to use Groq (Fase Design Mode):")
|
||||
try:
|
||||
stream = await client.chat.completions.create(
|
||||
response = await groq_client.chat.completions.create(
|
||||
messages=messages,
|
||||
# model='gemma-7b-it',
|
||||
model="mixtral-8x7b-32768",
|
||||
@ -92,9 +98,9 @@ async def _achat_completion_stream_groq(messages: list[dict]) -> str:
|
||||
if attempt < max_attempts - 1: # i is zero indexed
|
||||
continue
|
||||
else:
|
||||
raise "failed"
|
||||
raise Exception("failed")
|
||||
|
||||
full_reply_content = stream.choices[0].message.content
|
||||
full_reply_content = response.choices[0].message.content
|
||||
print(colored(full_reply_content, "blue", "on_white"), end="")
|
||||
print()
|
||||
return full_reply_content
|
||||
@ -103,14 +109,14 @@ async def _achat_completion_stream_groq(messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
|
||||
from mistralai.client import MistralClient
|
||||
from mistralai.models.chat_completion import ChatMessage
|
||||
client = MistralClient(api_key=MISTRAL_API_KEY)
|
||||
mistral_client = MistralClient(api_key=MISTRAL_API_KEY)
|
||||
# client=AsyncGroq(api_key=GROQ_API_KEY)
|
||||
max_attempts = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
messages[len(messages) - 1]["role"] = "user"
|
||||
stream = client.chat(
|
||||
stream = mistral_client.chat(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=message["role"], content=message["content"]
|
||||
@ -119,14 +125,13 @@ async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
|
||||
],
|
||||
# model = "mistral-small-latest",
|
||||
model="open-mixtral-8x7b",
|
||||
# response_format={"type": "json_object"},
|
||||
)
|
||||
break # If the operation is successful, break the loop
|
||||
except Exception:
|
||||
if attempt < max_attempts - 1: # i is zero indexed
|
||||
continue
|
||||
else:
|
||||
raise "failed"
|
||||
raise Exception("failed")
|
||||
|
||||
full_reply_content = stream.choices[0].message.content
|
||||
print(colored(full_reply_content, "blue", "on_white"), end="")
|
||||
@ -135,15 +140,11 @@ async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
|
||||
|
||||
|
||||
async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
openai.api_base = OPENAI_API_BASE
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
response = await async_client.chat.completions.create(
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
n=1,
|
||||
stop=None,
|
||||
temperature=0.3,
|
||||
timeout=3,
|
||||
timeout=30,
|
||||
model="gpt-3.5-turbo-16k",
|
||||
stream=True,
|
||||
)
|
||||
@ -154,40 +155,33 @@ async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
choices = chunk["choices"]
|
||||
choices = chunk.choices
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk["choices"][0].get(
|
||||
"delta", {}
|
||||
) # extract the message
|
||||
chunk_message = chunk.choices[0].delta
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
if chunk_message.content:
|
||||
print(
|
||||
colored(chunk_message["content"], "blue", "on_white"),
|
||||
colored(chunk_message.content, "blue", "on_white"),
|
||||
end="",
|
||||
)
|
||||
print()
|
||||
|
||||
full_reply_content = "".join(
|
||||
[m.get("content", "") for m in collected_messages]
|
||||
[m.content or "" for m in collected_messages if m is not None]
|
||||
)
|
||||
return full_reply_content
|
||||
|
||||
|
||||
async def _achat_completion_json(messages: list[dict]) -> str:
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
openai.api_base = OPENAI_API_BASE
|
||||
|
||||
max_attempts = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
stream = await openai.ChatCompletion.acreate(
|
||||
response = await async_client.chat.completions.create(
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
n=1,
|
||||
stop=None,
|
||||
temperature=0.3,
|
||||
timeout=3,
|
||||
timeout=30,
|
||||
model=MODEL,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
@ -196,60 +190,62 @@ async def _achat_completion_json(messages: list[dict]) -> str:
|
||||
if attempt < max_attempts - 1: # i is zero indexed
|
||||
continue
|
||||
else:
|
||||
raise "failed"
|
||||
raise Exception("failed")
|
||||
|
||||
full_reply_content = stream.choices[0].message.content
|
||||
full_reply_content = response.choices[0].message.content
|
||||
print(colored(full_reply_content, "blue", "on_white"), end="")
|
||||
print()
|
||||
return full_reply_content
|
||||
|
||||
|
||||
async def _achat_completion_stream(messages: list[dict]) -> str:
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
openai.api_base = OPENAI_API_BASE
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
**_cons_kwargs(messages), stream=True
|
||||
)
|
||||
try:
|
||||
response = await async_client.chat.completions.create(
|
||||
**_cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
choices = chunk["choices"]
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk["choices"][0].get(
|
||||
"delta", {}
|
||||
) # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(
|
||||
colored(chunk_message["content"], "blue", "on_white"),
|
||||
end="",
|
||||
)
|
||||
print()
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
choices = chunk.choices
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk.choices[0].delta
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if chunk_message.content:
|
||||
print(
|
||||
colored(chunk_message.content, "blue", "on_white"),
|
||||
end="",
|
||||
)
|
||||
print()
|
||||
|
||||
full_reply_content = "".join(
|
||||
[m.get("content", "") for m in collected_messages]
|
||||
)
|
||||
return full_reply_content
|
||||
full_reply_content = "".join(
|
||||
[m.content or "" for m in collected_messages if m is not None]
|
||||
)
|
||||
return full_reply_content
|
||||
except Exception as e:
|
||||
print_colored(f"OpenAI API error in _achat_completion_stream: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
def _chat_completion(messages: list[dict]) -> str:
|
||||
rsp = openai.ChatCompletion.create(**_cons_kwargs(messages))
|
||||
content = rsp["choices"][0]["message"]["content"]
|
||||
return content
|
||||
try:
|
||||
rsp = client.chat.completions.create(**_cons_kwargs(messages))
|
||||
content = rsp.choices[0].message.content
|
||||
return content
|
||||
except Exception as e:
|
||||
print_colored(f"OpenAI API error in _chat_completion: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
def _cons_kwargs(messages: list[dict]) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": 4096,
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.5,
|
||||
"timeout": 3,
|
||||
"max_tokens": 2000,
|
||||
"temperature": 0.3,
|
||||
"timeout": 15,
|
||||
}
|
||||
kwargs_mode = {"model": MODEL}
|
||||
kwargs.update(kwargs_mode)
|
||||
|
||||
@ -80,6 +80,10 @@ def generate_AbilityRequirement(General_Goal, Current_Task):
|
||||
|
||||
|
||||
def generate_AgentSelection(General_Goal, Current_Task, Agent_Board):
|
||||
# Check if Agent_Board is None or empty
|
||||
if Agent_Board is None or len(Agent_Board) == 0:
|
||||
raise ValueError("Agent_Board cannot be None or empty. Please ensure agents are set via /setAgents endpoint before generating a plan.")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
||||
@ -83,4 +83,16 @@ def generate_PlanOutline(InitialObject_List, General_Goal):
|
||||
),
|
||||
},
|
||||
]
|
||||
return read_LLM_Completion(messages)["Plan_Outline"]
|
||||
result = read_LLM_Completion(messages)
|
||||
if isinstance(result, dict) and "Plan_Outline" in result:
|
||||
return result["Plan_Outline"]
|
||||
else:
|
||||
# 如果格式不正确,返回默认的计划大纲
|
||||
return [
|
||||
{
|
||||
"StepName": "Default Step",
|
||||
"TaskContent": "Generated default plan step due to format error",
|
||||
"InputObject_List": [],
|
||||
"OutputObject": "Default Output"
|
||||
}
|
||||
]
|
||||
|
||||
@ -80,7 +80,14 @@ class BaseAction():
|
||||
Important_Mark = ""
|
||||
action_Record += PROMPT_TEMPLATE_ACTION_RECORD.format(AgentName = actionInfo["AgentName"], Action_Description = actionInfo["AgentName"], Action_Result = actionInfo["Action_Result"], Important_Mark = Important_Mark)
|
||||
|
||||
prompt = PROMPT_TEMPLATE_TAKE_ACTION_BASE.format(agentName = agentName, agentProfile = AgentProfile_Dict[agentName], General_Goal = General_Goal, Current_Task_Description = TaskDescription, Input_Objects = inputObject_Record, History_Action = action_Record, Action_Description = self.info["Description"], Action_Custom_Note = self.Action_Custom_Note)
|
||||
# Handle missing agent profiles gracefully
|
||||
if agentName not in AgentProfile_Dict:
|
||||
print_colored(text=f"Warning: Agent '{agentName}' not found in AgentProfile_Dict. Using default profile.", text_color="yellow")
|
||||
agentProfile = f"AI Agent named {agentName}"
|
||||
else:
|
||||
agentProfile = AgentProfile_Dict[agentName]
|
||||
|
||||
prompt = PROMPT_TEMPLATE_TAKE_ACTION_BASE.format(agentName = agentName, agentProfile = agentProfile, General_Goal = General_Goal, Current_Task_Description = TaskDescription, Input_Objects = inputObject_Record, History_Action = action_Record, Action_Description = self.info["Description"], Action_Custom_Note = self.Action_Custom_Note)
|
||||
print_colored(text = prompt, text_color="red")
|
||||
messages = [{"role":"system", "content": prompt}]
|
||||
ActionResult = LLM_Completion(messages,True,False)
|
||||
|
||||
@ -85,9 +85,17 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
|
||||
# start the group chat
|
||||
util.print_colored(TaskDescription, text_color="green")
|
||||
ActionHistory = []
|
||||
action_count = 0
|
||||
total_actions = len(TaskProcess)
|
||||
|
||||
for ActionInfo in TaskProcess:
|
||||
action_count += 1
|
||||
actionType = ActionInfo["ActionType"]
|
||||
agentName = ActionInfo["AgentName"]
|
||||
|
||||
# 添加进度日志
|
||||
util.print_colored(f"🔄 Executing action {action_count}/{total_actions}: {actionType} by {agentName}", text_color="yellow")
|
||||
|
||||
if actionType in Action.customAction_Dict:
|
||||
currentAction = Action.customAction_Dict[actionType](
|
||||
info=ActionInfo,
|
||||
|
||||
@ -20,6 +20,8 @@ def camel_case_to_normal(s):
|
||||
def generate_template_sentence_for_CollaborationBrief(
|
||||
input_object_list, output_object, agent_list, step_task
|
||||
):
|
||||
# Ensure step_task is not None
|
||||
step_task = step_task if step_task is not None else "perform the task"
|
||||
# Check if the names are in camel case (no spaces) and convert them to normal naming convention
|
||||
input_object_list = (
|
||||
[
|
||||
@ -31,29 +33,48 @@ def generate_template_sentence_for_CollaborationBrief(
|
||||
)
|
||||
output_object = (
|
||||
camel_case_to_normal(output_object)
|
||||
if is_camel_case(output_object)
|
||||
else output_object
|
||||
if output_object is not None and is_camel_case(output_object)
|
||||
else (output_object if output_object is not None else "unknown output")
|
||||
)
|
||||
|
||||
# Format the agents into a string with proper grammar
|
||||
agent_str = (
|
||||
" and ".join([", ".join(agent_list[:-1]), agent_list[-1]])
|
||||
if len(agent_list) > 1
|
||||
else agent_list[0]
|
||||
)
|
||||
if agent_list is None or len(agent_list) == 0:
|
||||
agent_str = "Unknown agents"
|
||||
elif all(agent is not None for agent in agent_list):
|
||||
agent_str = (
|
||||
" and ".join([", ".join(agent_list[:-1]), agent_list[-1]])
|
||||
if len(agent_list) > 1
|
||||
else agent_list[0]
|
||||
)
|
||||
else:
|
||||
# Filter out None values
|
||||
filtered_agents = [agent for agent in agent_list if agent is not None]
|
||||
if filtered_agents:
|
||||
agent_str = (
|
||||
" and ".join([", ".join(filtered_agents[:-1]), filtered_agents[-1]])
|
||||
if len(filtered_agents) > 1
|
||||
else filtered_agents[0]
|
||||
)
|
||||
else:
|
||||
agent_str = "Unknown agents"
|
||||
|
||||
if input_object_list is None or len(input_object_list) == 0:
|
||||
# Combine all the parts into the template sentence
|
||||
template_sentence = f"{agent_str} perform the task of {step_task} to obtain {output_object}."
|
||||
else:
|
||||
# Format the input objects into a string with proper grammar
|
||||
input_str = (
|
||||
" and ".join(
|
||||
[", ".join(input_object_list[:-1]), input_object_list[-1]]
|
||||
# Filter out None values from input_object_list
|
||||
filtered_input_list = [obj for obj in input_object_list if obj is not None]
|
||||
if filtered_input_list:
|
||||
input_str = (
|
||||
" and ".join(
|
||||
[", ".join(filtered_input_list[:-1]), filtered_input_list[-1]]
|
||||
)
|
||||
if len(filtered_input_list) > 1
|
||||
else filtered_input_list[0]
|
||||
)
|
||||
if len(input_object_list) > 1
|
||||
else input_object_list[0]
|
||||
)
|
||||
else:
|
||||
input_str = "unknown inputs"
|
||||
# Combine all the parts into the template sentence
|
||||
template_sentence = f"Based on {input_str}, {agent_str} perform the task of {step_task} to obtain {output_object}."
|
||||
|
||||
@ -90,7 +111,7 @@ def read_LLM_Completion(messages, useGroq=True):
|
||||
return json.loads(match.group(0).strip())
|
||||
except Exception:
|
||||
pass
|
||||
raise ("bad format!")
|
||||
return {} # 返回空对象而不是抛出异常
|
||||
|
||||
|
||||
def read_json_content(text):
|
||||
@ -111,7 +132,7 @@ def read_json_content(text):
|
||||
if match:
|
||||
return json.loads(match.group(0).strip())
|
||||
|
||||
raise ("bad format!")
|
||||
return {} # 返回空对象而不是抛出异常
|
||||
|
||||
|
||||
def read_outputObject_content(text, keyword):
|
||||
@ -127,4 +148,4 @@ def read_outputObject_content(text, keyword):
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
else:
|
||||
raise ("bad format!")
|
||||
return "" # 返回空字符串而不是抛出异常
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
Flask==3.0.2
|
||||
openai==0.28.1
|
||||
openai==2.8.1
|
||||
PyYAML==6.0.1
|
||||
termcolor==2.4.0
|
||||
groq==0.4.2
|
||||
|
||||
@ -213,12 +213,17 @@ def Handle_generate_basePlan():
|
||||
if requestIdentifier in Request_Cache:
|
||||
return jsonify(Request_Cache[requestIdentifier])
|
||||
|
||||
basePlan = generate_basePlan(
|
||||
General_Goal=incoming_data["General Goal"],
|
||||
Agent_Board=AgentBoard,
|
||||
AgentProfile_Dict=AgentProfile_Dict,
|
||||
InitialObject_List=incoming_data["Initial Input Object"],
|
||||
)
|
||||
try:
|
||||
basePlan = generate_basePlan(
|
||||
General_Goal=incoming_data["General Goal"],
|
||||
Agent_Board=AgentBoard,
|
||||
AgentProfile_Dict=AgentProfile_Dict,
|
||||
InitialObject_List=incoming_data["Initial Input Object"],
|
||||
)
|
||||
except ValueError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500
|
||||
basePlan_withRenderSpec = Add_Collaboration_Brief_FrontEnd(basePlan)
|
||||
Request_Cache[requestIdentifier] = basePlan_withRenderSpec
|
||||
response = jsonify(basePlan_withRenderSpec)
|
||||
@ -278,10 +283,38 @@ def set_agents():
|
||||
|
||||
def init():
|
||||
global AgentBoard, AgentProfile_Dict, Request_Cache
|
||||
with open(
|
||||
os.path.join(os.getcwd(), "RequestCache", "Request_Cache.json"), "r"
|
||||
) as json_file:
|
||||
Request_Cache = json.load(json_file)
|
||||
|
||||
# Load Request Cache
|
||||
try:
|
||||
with open(
|
||||
os.path.join(os.getcwd(), "RequestCache", "Request_Cache.json"), "r"
|
||||
) as json_file:
|
||||
Request_Cache = json.load(json_file)
|
||||
print(f"✅ Loaded Request_Cache with {len(Request_Cache)} entries")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load Request_Cache: {e}")
|
||||
Request_Cache = {}
|
||||
|
||||
# Load Agent Board
|
||||
try:
|
||||
with open(
|
||||
os.path.join(os.getcwd(), "AgentRepo", "agentBoard_v1.json"), "r", encoding="utf-8"
|
||||
) as json_file:
|
||||
AgentBoard = json.load(json_file)
|
||||
print(f"✅ Loaded AgentBoard with {len(AgentBoard)} agents")
|
||||
|
||||
# Build AgentProfile_Dict
|
||||
AgentProfile_Dict = {}
|
||||
for item in AgentBoard:
|
||||
name = item["Name"]
|
||||
profile = item["Profile"]
|
||||
AgentProfile_Dict[name] = profile
|
||||
print(f"✅ Built AgentProfile_Dict with {len(AgentProfile_Dict)} profiles")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load AgentBoard: {e}")
|
||||
AgentBoard = []
|
||||
AgentProfile_Dict = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -291,8 +324,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8017,
|
||||
help="set the port number, 8017 by defaul.",
|
||||
default=8000,
|
||||
help="set the port number, 8000 by defaul.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
init()
|
||||
|
||||
@ -17,7 +17,7 @@ import _ from 'lodash';
|
||||
// fakeAgentSelections,
|
||||
// fakeCurrentAgentSelection,
|
||||
// } from './data/fakeAgentAssignment';
|
||||
import CheckIcon from '@/icons/CheckIcon';
|
||||
import CheckIcon from '@/icons/checkIcon';
|
||||
import AgentIcon from '@/components/AgentIcon';
|
||||
import { globalStorage } from '@/storage';
|
||||
import SendIcon from '@/icons/SendIcon';
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user