import asyncio
import os

import httpx
import yaml
from openai import AsyncOpenAI, OpenAI  # `max_retries` is not an openai export; removed
from termcolor import colored


# Helper function to avoid circular import
def print_colored(text, text_color="green", background="on_white"):
    print(colored(text, text_color, background))


# Load config (API key, API base, model)
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
    with open(yaml_file, "r", encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file) or {}  # guard: safe_load returns None for an empty file
except Exception:
    yaml_data = {}

OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or yaml_data.get(
    "OPENAI_API_BASE", "https://api.openai.com"
)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or yaml_data.get("OPENAI_API_KEY", "")
OPENAI_API_MODEL = os.getenv("OPENAI_API_MODEL") or yaml_data.get("OPENAI_API_MODEL", "")

# Initialize OpenAI clients
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
async_client = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)

MODEL: str = os.getenv("OPENAI_API_MODEL") or yaml_data.get(
    "OPENAI_API_MODEL", "gpt-4-turbo-preview"
)

# The environment variable arrives as a string, so parse it into a bool
_fast_design_env = os.getenv("FAST_DESIGN_MODE")
if _fast_design_env is None:
    FAST_DESIGN_MODE: bool = yaml_data.get("FAST_DESIGN_MODE", False)
else:
    FAST_DESIGN_MODE = _fast_design_env.lower() in ["true", "1", "yes"]

GROQ_API_KEY = os.getenv("GROQ_API_KEY") or yaml_data.get("GROQ_API_KEY", "")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") or yaml_data.get("MISTRAL_API_KEY", "")
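
# A config/config.yaml for the lookups above might look like this
# (illustrative sketch; the values are placeholders, the keys match
# exactly those read from yaml_data above):
#
#   OPENAI_API_BASE: "https://api.openai.com"
#   OPENAI_API_KEY: "sk-..."
#   OPENAI_API_MODEL: "gpt-4-turbo-preview"
#   FAST_DESIGN_MODE: false
#   GROQ_API_KEY: ""
#   MISTRAL_API_KEY: ""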


# LLM completion entry point: dispatches to Groq (Fast Design Mode), to the
# configured OpenAI-compatible endpoint, or to a caller-supplied custom config.
def LLM_Completion(
    messages: list[dict],
    stream: bool = True,
    useGroq: bool = True,
    model_config: dict = None,
) -> str:
    if model_config:
        print_colored(f"Using model config: {model_config}", "blue")
        return _call_with_custom_config(messages, stream, model_config)

    if not useGroq or not FAST_DESIGN_MODE:
        force_gpt4 = True
        useGroq = False
    else:
        force_gpt4 = False
        useGroq = True

    if stream:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError as ex:
            if "There is no current event loop in thread" in str(ex):
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        if useGroq:
            # Note: given the branch above, force_gpt4 is always False when
            # useGroq is True, so the JSON path below is currently unreachable.
            if force_gpt4:
                return loop.run_until_complete(
                    _achat_completion_json(messages=messages)
                )
            else:
                return loop.run_until_complete(
                    _achat_completion_stream_groq(messages=messages)
                )
        else:
            return loop.run_until_complete(
                _achat_completion_stream(messages=messages)
            )
        # return asyncio.run(_achat_completion_stream(messages=messages))
    else:
        return _chat_completion(messages=messages)
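
# Usage sketch (hypothetical prompt; assumes credentials are configured):
#
#   reply = LLM_Completion(
#       [{"role": "user", "content": "Summarize this repo in one line."}],
#       stream=True,
#       useGroq=False,  # bypass Fast Design Mode and use the OpenAI client
#   )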


def _call_with_custom_config(messages: list[dict], stream: bool, model_config: dict) -> str:
    """Call the API using a caller-supplied custom configuration."""
    api_url = model_config.get("apiUrl", OPENAI_API_BASE)
    api_key = model_config.get("apiKey", OPENAI_API_KEY)
    api_model = model_config.get("apiModel", OPENAI_API_MODEL)

    temp_client = OpenAI(api_key=api_key, base_url=api_url)
    temp_async_client = AsyncOpenAI(api_key=api_key, base_url=api_url)

    try:
        if stream:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError as ex:
                if "There is no current event loop in thread" in str(ex):
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
            return loop.run_until_complete(
                _achat_completion_stream_custom(
                    messages=messages,
                    temp_async_client=temp_async_client,
                    api_model=api_model,
                )
            )
        else:
            response = temp_client.chat.completions.create(
                messages=messages,
                model=api_model,
                temperature=0.3,
                max_tokens=4096,
                timeout=180,
            )
            # Check that the response is valid
            if not response.choices or len(response.choices) == 0:
                raise Exception(f"API returned empty response for model {api_model}")
            if not response.choices[0] or not response.choices[0].message:
                raise Exception(f"API returned invalid response format for model {api_model}")

            full_reply_content = response.choices[0].message.content
            if full_reply_content is None:
                raise Exception(f"API returned None content for model {api_model}")

            print(colored(full_reply_content, "blue", "on_white"), end="")
            return full_reply_content
    except Exception as e:
        print_colored(f"Custom API error for model {api_model}: {str(e)}", "red")
        raise
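
# The model_config dict consumed above uses these keys (values are placeholders):
#
#   model_config = {
#       "apiUrl": "https://api.example.com/v1",
#       "apiKey": "sk-...",
#       "apiModel": "some-model-name",
#   }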


async def _achat_completion_stream_custom(
    messages: list[dict], temp_async_client: AsyncOpenAI, api_model: str
) -> str:
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = await temp_async_client.chat.completions.create(
                messages=messages,
                model=api_model,
                temperature=0.3,
                max_tokens=4096,
                stream=True,
                timeout=180,
            )

            collected_chunks = []
            collected_messages = []
            async for chunk in response:
                collected_chunks.append(chunk)
                choices = chunk.choices
                if len(choices) > 0 and choices[0] is not None:
                    chunk_message = choices[0].delta
                    if chunk_message is not None:
                        collected_messages.append(chunk_message)
                        if chunk_message.content:
                            print(colored(chunk_message.content, "blue", "on_white"), end="")
            print()
            full_reply_content = "".join(
                [m.content or "" for m in collected_messages if m is not None]
            )

            # Check that the final result is not empty
            if not full_reply_content or full_reply_content.strip() == "":
                raise Exception(f"Stream API returned empty content for model {api_model}")

            return full_reply_content
        except httpx.RemoteProtocolError:
            if attempt < max_retries - 1:
                wait_time = (attempt + 1) * 2
                print_colored(
                    f"⚠️ Stream connection interrupted "
                    f"(attempt {attempt + 1}/{max_retries}). Retrying in {wait_time}s...",
                    text_color="yellow",
                )
                await asyncio.sleep(wait_time)
                continue
            raise  # out of retries: propagate instead of silently returning None
        except Exception as e:
            print_colored(f"Custom API stream error for model {api_model}: {str(e)}", "red")
            raise


async def _achat_completion_stream_groq(messages: list[dict]) -> str:
    from groq import AsyncGroq

    groq_client = AsyncGroq(api_key=GROQ_API_KEY)

    max_attempts = 5

    for attempt in range(max_attempts):
        print("Attempting to use Groq (Fast Design Mode):")
        try:
            response = await groq_client.chat.completions.create(
                messages=messages,
                # model='gemma-7b-it',
                model="mixtral-8x7b-32768",
                # model='llama2-70b-4096',
                temperature=0.3,
                response_format={"type": "json_object"},
                stream=False,
            )
            break
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise Exception(f"Groq API request failed after {max_attempts} attempts")

    # Check that the response is valid
    if not response.choices or len(response.choices) == 0:
        raise Exception("Groq API returned empty response")
    if not response.choices[0] or not response.choices[0].message:
        raise Exception("Groq API returned invalid response format")

    full_reply_content = response.choices[0].message.content
    if full_reply_content is None:
        raise Exception("Groq API returned None content")

    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    mistral_client = MistralClient(api_key=MISTRAL_API_KEY)

    max_attempts = 5

    for attempt in range(max_attempts):
        try:
            # Ensure the final message has role "user", as the Mistral chat
            # endpoint expects the conversation to end with a user message
            messages[-1]["role"] = "user"
            response = mistral_client.chat(
                messages=[
                    ChatMessage(role=message["role"], content=message["content"])
                    for message in messages
                ],
                # model = "mistral-small-latest",
                model="open-mixtral-8x7b",
            )
            break  # If the operation is successful, break the loop
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise Exception(f"Mistral API request failed after {max_attempts} attempts")

    # Check that the response is valid
    if not response.choices or len(response.choices) == 0:
        raise Exception("Mistral API returned empty response")
    if not response.choices[0] or not response.choices[0].message:
        raise Exception("Mistral API returned invalid response format")

    full_reply_content = response.choices[0].message.content
    if full_reply_content is None:
        raise Exception("Mistral API returned None content")

    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
    response = await async_client.chat.completions.create(
        messages=messages,
        max_tokens=4096,
        temperature=0.3,
        timeout=600,
        model="gpt-3.5-turbo-16k",
        stream=True,
    )

    # Create variables to collect the stream of chunks
    collected_chunks = []
    collected_messages = []
    # Iterate through the stream of events
    async for chunk in response:
        collected_chunks.append(chunk)  # save the event response
        choices = chunk.choices
        if len(choices) > 0 and choices[0] is not None:
            chunk_message = choices[0].delta
            if chunk_message is not None:
                collected_messages.append(chunk_message)  # save the message
                if chunk_message.content:
                    print(colored(chunk_message.content, "blue", "on_white"), end="")
    print()

    full_reply_content = "".join(
        [m.content or "" for m in collected_messages if m is not None]
    )

    # Check that the final result is not empty
    if not full_reply_content or full_reply_content.strip() == "":
        raise Exception("Stream API (gpt-3.5) returned empty content")

    return full_reply_content


async def _achat_completion_json(messages: list[dict]) -> str:
    # Must be async: callers drive it with loop.run_until_complete, and the
    # async client's create() has to be awaited.
    max_attempts = 5
    for attempt in range(max_attempts):
        try:
            response = await async_client.chat.completions.create(
                messages=messages,
                max_tokens=4096,
                temperature=0.3,
                timeout=600,
                model=MODEL,
                response_format={"type": "json_object"},
            )
            break
        except Exception:
            if attempt < max_attempts - 1:  # attempt is zero-indexed
                continue
            raise Exception(f"OpenAI API request failed after {max_attempts} attempts")

    # Check that the response is valid
    if not response.choices or len(response.choices) == 0:
        raise Exception("OpenAI API returned empty response")
    if not response.choices[0] or not response.choices[0].message:
        raise Exception("OpenAI API returned invalid response format")

    full_reply_content = response.choices[0].message.content
    if full_reply_content is None:
        raise Exception("OpenAI API returned None content")

    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content


async def _achat_completion_stream(messages: list[dict]) -> str:
    try:
        response = await async_client.chat.completions.create(
            **_cons_kwargs(messages), stream=True
        )

        # Create variables to collect the stream of chunks
        collected_chunks = []
        collected_messages = []
        # Iterate through the stream of events
        async for chunk in response:
            collected_chunks.append(chunk)  # save the event response
            choices = chunk.choices
            if len(choices) > 0 and choices[0] is not None:
                chunk_message = choices[0].delta
                if chunk_message is not None:
                    collected_messages.append(chunk_message)  # save the message
                    if chunk_message.content:
                        print(colored(chunk_message.content, "blue", "on_white"), end="")
        print()

        full_reply_content = "".join(
            [m.content or "" for m in collected_messages if m is not None]
        )

        # Check that the final result is not empty
        if not full_reply_content or full_reply_content.strip() == "":
            raise Exception("Stream API returned empty content")

        return full_reply_content
    except Exception as e:
        print_colored(f"OpenAI API error in _achat_completion_stream: {str(e)}", "red")
        raise


def _chat_completion(messages: list[dict]) -> str:
    try:
        rsp = client.chat.completions.create(**_cons_kwargs(messages))

        # Check that the response is valid
        if not rsp.choices or len(rsp.choices) == 0:
            raise Exception("OpenAI API returned empty response")
        if not rsp.choices[0] or not rsp.choices[0].message:
            raise Exception("OpenAI API returned invalid response format")

        content = rsp.choices[0].message.content
        if content is None:
            raise Exception("OpenAI API returned None content")

        return content
    except Exception as e:
        print_colored(f"OpenAI API error in _chat_completion: {str(e)}", "red")
        raise


def _cons_kwargs(messages: list[dict]) -> dict:
    # Shared request kwargs for the default client calls
    return {
        "messages": messages,
        "max_tokens": 2000,
        "temperature": 0.3,
        "timeout": 600,
        "model": MODEL,
    }
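

if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes valid credentials in
    # config/config.yaml or the environment; the prompt is just an example).
    reply = LLM_Completion(
        [{"role": "user", "content": "Say hello in one sentence."}],
        stream=False,
        useGroq=False,
    )
    print_colored(reply)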