feat:单个agent配置各自的apiurl、apimodel、apikey
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI, AsyncOpenAI, max_retries
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
import os
|
||||
@@ -21,6 +23,9 @@ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or yaml_data.get(
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or yaml_data.get(
|
||||
"OPENAI_API_KEY", ""
|
||||
)
|
||||
OPENAI_API_MODEL = os.getenv("OPENAI_API_MODEL") or yaml_data.get(
|
||||
"OPENAI_API_MODEL", ""
|
||||
)
|
||||
|
||||
# Initialize OpenAI clients
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
|
||||
@@ -41,8 +46,11 @@ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") or yaml_data.get(
|
||||
|
||||
# for LLM completion
|
||||
def LLM_Completion(
|
||||
messages: list[dict], stream: bool = True, useGroq: bool = True
|
||||
messages: list[dict], stream: bool = True, useGroq: bool = True,model_config: dict = None
|
||||
) -> str:
|
||||
if model_config:
|
||||
print_colored(f"Using model config: {model_config}", "blue")
|
||||
return _call_with_custom_config(messages,stream,model_config)
|
||||
if not useGroq or not FAST_DESIGN_MODE:
|
||||
force_gpt4 = True
|
||||
useGroq = False
|
||||
@@ -75,6 +83,82 @@ def LLM_Completion(
|
||||
return _chat_completion(messages=messages)
|
||||
|
||||
|
||||
def _call_with_custom_config(messages: list[dict], stream: bool, model_config: dict) ->str:
|
||||
"使用自定义配置调用API"
|
||||
api_url = model_config.get("apiUrl", OPENAI_API_BASE)
|
||||
api_key = model_config.get("apiKey", OPENAI_API_KEY)
|
||||
api_model = model_config.get("apiModel", OPENAI_API_MODEL)
|
||||
|
||||
temp_client = OpenAI(api_key=api_key, base_url=api_url)
|
||||
temp_async_client = AsyncOpenAI(api_key=api_key, base_url=api_url)
|
||||
|
||||
try:
|
||||
if stream:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError as ex:
|
||||
if "There is no current event loop in thread" in str(ex):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(
|
||||
_achat_completion_stream_custom(messages=messages, temp_async_client=temp_async_client, api_model=api_model)
|
||||
)
|
||||
else:
|
||||
response = temp_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=api_model,
|
||||
temperature=0.3,
|
||||
max_tokens=4096,
|
||||
timeout=180
|
||||
|
||||
)
|
||||
full_reply_content = response.choices[0].message.content
|
||||
print(colored(full_reply_content, "blue", "on_white"), end="")
|
||||
return full_reply_content
|
||||
except Exception as e:
|
||||
print_colored(f"Custom API error for model {api_model} :{str(e)}","red")
|
||||
raise
|
||||
|
||||
|
||||
async def _achat_completion_stream_custom(messages:list[dict], temp_async_client, api_model: str ) -> str:
|
||||
max_retries=3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await temp_async_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=api_model,
|
||||
temperature=0.3,
|
||||
max_tokens=4096,
|
||||
stream=True,
|
||||
timeout=180
|
||||
)
|
||||
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk)
|
||||
choices = chunk.choices
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk.choices[0].delta
|
||||
collected_messages.append(chunk_message)
|
||||
if chunk_message.content:
|
||||
print(colored(chunk_message.content, "blue", "on_white"), end="")
|
||||
print()
|
||||
full_reply_content = "".join(
|
||||
[m.content or "" for m in collected_messages if m is not None]
|
||||
)
|
||||
return full_reply_content
|
||||
except httpx.RemoteProtocolError as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = (attempt + 1) *2
|
||||
print_colored(f"⚠️ Stream connection interrupted (attempt {attempt+1}/{max_retries}). Retrying in {wait_time}s...", text_color="yellow")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
except Exception as e:
|
||||
print_colored(f"Custom API stream error for model {api_model} :{str(e)}","red")
|
||||
raise
|
||||
|
||||
|
||||
async def _achat_completion_stream_groq(messages: list[dict]) -> str:
|
||||
from groq import AsyncGroq
|
||||
groq_client = AsyncGroq(api_key=GROQ_API_KEY)
|
||||
@@ -144,7 +228,7 @@ async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
temperature=0.3,
|
||||
timeout=30,
|
||||
timeout=600,
|
||||
model="gpt-3.5-turbo-16k",
|
||||
stream=True,
|
||||
)
|
||||
@@ -172,16 +256,15 @@ async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
|
||||
return full_reply_content
|
||||
|
||||
|
||||
async def _achat_completion_json(messages: list[dict]) -> str:
|
||||
def _achat_completion_json(messages: list[dict] ) -> str:
|
||||
max_attempts = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = await async_client.chat.completions.create(
|
||||
response = async_client.chat.completions.create(
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
temperature=0.3,
|
||||
timeout=30,
|
||||
timeout=600,
|
||||
model=MODEL,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
@@ -245,7 +328,7 @@ def _cons_kwargs(messages: list[dict]) -> dict:
|
||||
"messages": messages,
|
||||
"max_tokens": 2000,
|
||||
"temperature": 0.3,
|
||||
"timeout": 60,
|
||||
"timeout": 600,
|
||||
}
|
||||
kwargs_mode = {"model": MODEL}
|
||||
kwargs.update(kwargs_mode)
|
||||
|
||||
@@ -81,16 +81,33 @@ class BaseAction():
|
||||
action_Record += PROMPT_TEMPLATE_ACTION_RECORD.format(AgentName = actionInfo["AgentName"], Action_Description = actionInfo["AgentName"], Action_Result = actionInfo["Action_Result"], Important_Mark = Important_Mark)
|
||||
|
||||
# Handle missing agent profiles gracefully
|
||||
model_config = None
|
||||
if agentName not in AgentProfile_Dict:
|
||||
print_colored(text=f"Warning: Agent '{agentName}' not found in AgentProfile_Dict. Using default profile.", text_color="yellow")
|
||||
agentProfile = f"AI Agent named {agentName}"
|
||||
else:
|
||||
agentProfile = AgentProfile_Dict[agentName]
|
||||
|
||||
prompt = PROMPT_TEMPLATE_TAKE_ACTION_BASE.format(agentName = agentName, agentProfile = agentProfile, General_Goal = General_Goal, Current_Task_Description = TaskDescription, Input_Objects = inputObject_Record, History_Action = action_Record, Action_Description = self.info["Description"], Action_Custom_Note = self.Action_Custom_Note)
|
||||
# agentProfile = AgentProfile_Dict[agentName]
|
||||
agent_config = AgentProfile_Dict[agentName]
|
||||
agentProfile = agent_config.get("profile",f"AI Agent named {agentName}")
|
||||
if agent_config.get("useCustomAPI",False):
|
||||
model_config = {
|
||||
"apiModel":agent_config.get("apiModel"),
|
||||
"apiUrl":agent_config.get("apiUrl"),
|
||||
"apiKey":agent_config.get("apiKey"),
|
||||
}
|
||||
prompt = PROMPT_TEMPLATE_TAKE_ACTION_BASE.format(
|
||||
agentName = agentName,
|
||||
agentProfile = agentProfile,
|
||||
General_Goal = General_Goal,
|
||||
Current_Task_Description = TaskDescription,
|
||||
Input_Objects = inputObject_Record,
|
||||
History_Action = action_Record,
|
||||
Action_Description = self.info["Description"],
|
||||
Action_Custom_Note = self.Action_Custom_Note
|
||||
)
|
||||
print_colored(text = prompt, text_color="red")
|
||||
messages = [{"role":"system", "content": prompt}]
|
||||
ActionResult = LLM_Completion(messages,True,False)
|
||||
ActionResult = LLM_Completion(messages,True,False,model_config=model_config)
|
||||
ActionInfo_with_Result = copy.deepcopy(self.info)
|
||||
ActionInfo_with_Result["Action_Result"] = ActionResult
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
## config for default LLM
|
||||
OPENAI_API_BASE: ""
|
||||
OPENAI_API_KEY: ""
|
||||
OPENAI_API_MODEL: "gpt-4-turbo-preview"
|
||||
OPENAI_API_BASE: "https://ai.gitee.com/v1"
|
||||
OPENAI_API_KEY: "HYCNGM39GGFNSB1F8MBBMI9QYJR3P1CRSYS2PV1A"
|
||||
OPENAI_API_MODEL: "DeepSeek-V3"
|
||||
|
||||
## config for fast mode
|
||||
FAST_DESIGN_MODE: True
|
||||
FAST_DESIGN_MODE: False
|
||||
GROQ_API_KEY: ""
|
||||
MISTRAL_API_KEY: ""
|
||||
|
||||
|
||||
@@ -271,13 +271,29 @@ def Handle_saveRequestCashe():
|
||||
|
||||
@app.route("/setAgents", methods=["POST"])
|
||||
def set_agents():
|
||||
global AgentBoard, AgentProfile_Dict
|
||||
global AgentBoard, AgentProfile_Dict,yaml_data
|
||||
AgentBoard = request.json
|
||||
AgentProfile_Dict = {}
|
||||
for item in AgentBoard:
|
||||
name = item["Name"]
|
||||
profile = item["Profile"]
|
||||
AgentProfile_Dict[name] = profile
|
||||
if all(item.get(field) for field in ["apiUrl","apiKey","apiModel"]):
|
||||
agent_config = {
|
||||
"profile": item["Profile"],
|
||||
"apiUrl": item["apiUrl"],
|
||||
"apiKey": item["apiKey"],
|
||||
"apiModel": item["apiModel"],
|
||||
"useCustomAPI":True
|
||||
}
|
||||
else:
|
||||
agent_config = {
|
||||
"profile": item["Profile"],
|
||||
"apiUrl": yaml_data.get("OPENAI_API_BASE"),
|
||||
"apiKey": yaml_data.get("OPENAI_API_KEY"),
|
||||
"apiModel": yaml_data.get("OPENAI_API_MODEL"),
|
||||
"useCustomAPI":False
|
||||
}
|
||||
AgentProfile_Dict[name] = agent_config
|
||||
|
||||
return jsonify({"code": 200, "content": "set agentboard successfully"})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user