feat:1.数据库存储功能添加(初版)2.后端REST API版本代码清理
This commit is contained in:
@@ -49,7 +49,6 @@ def LLM_Completion(
|
||||
messages: list[dict], stream: bool = True, useGroq: bool = True,model_config: dict = None
|
||||
) -> str:
|
||||
if model_config:
|
||||
print_colored(f"Using model config: {model_config}", "blue")
|
||||
return _call_with_custom_config(messages,stream,model_config)
|
||||
if not useGroq or not FAST_DESIGN_MODE:
|
||||
force_gpt4 = True
|
||||
@@ -125,7 +124,7 @@ def _call_with_custom_config(messages: list[dict], stream: bool, model_config: d
|
||||
#print(colored(full_reply_content, "blue", "on_white"), end="")
|
||||
return full_reply_content
|
||||
except Exception as e:
|
||||
print_colored(f"Custom API error for model {api_model} :{str(e)}","red")
|
||||
print_colored(f"[API Error] Custom API error: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
@@ -166,11 +165,11 @@ async def _achat_completion_stream_custom(messages:list[dict], temp_async_client
|
||||
except httpx.RemoteProtocolError as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = (attempt + 1) *2
|
||||
print_colored(f"⚠️ Stream connection interrupted (attempt {attempt+1}/{max_retries}). Retrying in {wait_time}s...", text_color="yellow")
|
||||
print_colored(f"[API Warn] Stream connection interrupted, retrying in {wait_time}s...", "yellow")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
except Exception as e:
|
||||
print_colored(f"Custom API stream error for model {api_model} :{str(e)}","red")
|
||||
print_colored(f"[API Error] Custom API stream error: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
@@ -181,7 +180,6 @@ async def _achat_completion_stream_groq(messages: list[dict]) -> str:
|
||||
max_attempts = 5
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
print("Attempt to use Groq (Fase Design Mode):")
|
||||
try:
|
||||
response = await groq_client.chat.completions.create(
|
||||
messages=messages,
|
||||
@@ -363,7 +361,7 @@ async def _achat_completion_stream(messages: list[dict]) -> str:
|
||||
|
||||
return full_reply_content
|
||||
except Exception as e:
|
||||
print_colored(f"OpenAI API error in _achat_completion_stream: {str(e)}", "red")
|
||||
print_colored(f"[API Error] OpenAI API stream error: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
@@ -383,7 +381,7 @@ def _chat_completion(messages: list[dict]) -> str:
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
print_colored(f"OpenAI API error in _chat_completion: {str(e)}", "red")
|
||||
print_colored(f"[API Error] OpenAI API error: {str(e)}", "red")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ def generate_AbilityRequirement(General_Goal, Current_Task):
|
||||
),
|
||||
},
|
||||
]
|
||||
print(messages[1]["content"])
|
||||
return read_LLM_Completion(messages)["AbilityRequirement"]
|
||||
|
||||
|
||||
@@ -106,6 +105,7 @@ def generate_AgentSelection(General_Goal, Current_Task, Agent_Board):
|
||||
|
||||
while True:
|
||||
candidate = read_LLM_Completion(messages)["AgentSelectionPlan"]
|
||||
# 添加调试打印
|
||||
if len(candidate) > MAX_TEAM_SIZE:
|
||||
teamSize = random.randint(2, MAX_TEAM_SIZE)
|
||||
candidate = candidate[0:teamSize]
|
||||
@@ -113,7 +113,6 @@ def generate_AgentSelection(General_Goal, Current_Task, Agent_Board):
|
||||
continue
|
||||
AgentSelectionPlan = sorted(candidate)
|
||||
AgentSelectionPlan_set = set(AgentSelectionPlan)
|
||||
|
||||
# Check if every item in AgentSelectionPlan is in agentboard
|
||||
if AgentSelectionPlan_set.issubset(agentboard_set):
|
||||
break # If all items are in agentboard, break the loop
|
||||
|
||||
@@ -85,7 +85,6 @@ class BaseAction():
|
||||
# Handle missing agent profiles gracefully
|
||||
model_config = None
|
||||
if agentName not in AgentProfile_Dict:
|
||||
print_colored(text=f"Warning: Agent '{agentName}' not found in AgentProfile_Dict. Using default profile.", text_color="yellow")
|
||||
agentProfile = f"AI Agent named {agentName}"
|
||||
else:
|
||||
# agentProfile = AgentProfile_Dict[agentName]
|
||||
|
||||
@@ -83,7 +83,6 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
|
||||
}
|
||||
|
||||
# start the group chat
|
||||
util.print_colored(TaskDescription, text_color="green")
|
||||
ActionHistory = []
|
||||
action_count = 0
|
||||
total_actions = len(TaskProcess)
|
||||
@@ -93,9 +92,6 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
|
||||
actionType = ActionInfo["ActionType"]
|
||||
agentName = ActionInfo["AgentName"]
|
||||
|
||||
# 添加进度日志
|
||||
util.print_colored(f"🔄 Executing action {action_count}/{total_actions}: {actionType} by {agentName}", text_color="yellow")
|
||||
|
||||
if actionType in Action.customAction_Dict:
|
||||
currentAction = Action.customAction_Dict[actionType](
|
||||
info=ActionInfo,
|
||||
@@ -126,11 +122,4 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
|
||||
stepLogNode["ActionHistory"] = ActionHistory
|
||||
|
||||
# Return Output
|
||||
print(
|
||||
colored(
|
||||
"$Run " + str(StepRun_count) + "step$",
|
||||
color="black",
|
||||
on_color="on_white",
|
||||
)
|
||||
)
|
||||
return RehearsalLog
|
||||
|
||||
@@ -107,7 +107,6 @@ def get_parallel_batches(TaskProcess: List[Dict], dependency_map: Dict[int, List
|
||||
# 避免死循环
|
||||
remaining = [i for i in range(len(TaskProcess)) if i not in completed]
|
||||
if remaining:
|
||||
print(colored(f"警告: 检测到循环依赖,强制串行执行: {remaining}", "yellow"))
|
||||
ready_to_run = remaining[:1]
|
||||
else:
|
||||
break
|
||||
@@ -188,7 +187,8 @@ async def execute_step_async_streaming(
|
||||
KeyObjects: Dict,
|
||||
step_index: int,
|
||||
total_steps: int,
|
||||
execution_id: str = None
|
||||
execution_id: str = None,
|
||||
RehearsalLog: List = None # 用于追加日志到历史记录
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""
|
||||
异步执行单个步骤,支持流式返回
|
||||
@@ -276,8 +276,9 @@ async def execute_step_async_streaming(
|
||||
total_actions = len(TaskProcess)
|
||||
completed_actions = 0
|
||||
|
||||
# 步骤开始日志
|
||||
util.print_colored(
|
||||
f"📋 步骤 {step_index + 1}/{total_steps}: {StepName} ({total_actions} 个动作, 分 {len(batches)} 批并行执行)",
|
||||
f"📋 步骤 {step_index + 1}/{total_steps}: {StepName} ({total_actions} 个动作, 分 {len(batches)} 批执行)",
|
||||
text_color="cyan"
|
||||
)
|
||||
|
||||
@@ -286,11 +287,11 @@ async def execute_step_async_streaming(
|
||||
# 在每个批次执行前检查暂停状态
|
||||
should_continue = await execution_state_manager.async_check_pause(execution_id)
|
||||
if not should_continue:
|
||||
util.print_colored("🛑 用户请求停止执行", "red")
|
||||
return
|
||||
|
||||
batch_size = len(batch_indices)
|
||||
|
||||
# 批次执行日志
|
||||
if batch_size > 1:
|
||||
util.print_colored(
|
||||
f"🚦 批次 {batch_index + 1}/{len(batches)}: 并行执行 {batch_size} 个动作",
|
||||
@@ -298,7 +299,7 @@ async def execute_step_async_streaming(
|
||||
)
|
||||
else:
|
||||
util.print_colored(
|
||||
f"🔄 动作 {completed_actions + 1}/{total_actions}: 串行执行",
|
||||
f"🔄 批次 {batch_index + 1}/{len(batches)}: 串行执行",
|
||||
text_color="yellow"
|
||||
)
|
||||
|
||||
@@ -353,12 +354,21 @@ async def execute_step_async_streaming(
|
||||
objectLogNode["content"] = KeyObjects[OutputName]
|
||||
stepLogNode["ActionHistory"] = ActionHistory
|
||||
|
||||
# 收集该步骤使用的 agent(去重)
|
||||
assigned_agents_in_step = list(set(Agent_List)) if Agent_List else []
|
||||
|
||||
# 追加到 RehearsalLog(因为 RehearsalLog 是可变对象,会反映到原列表)
|
||||
if RehearsalLog is not None:
|
||||
RehearsalLog.append(stepLogNode)
|
||||
RehearsalLog.append(objectLogNode)
|
||||
|
||||
yield {
|
||||
"type": "step_complete",
|
||||
"step_index": step_index,
|
||||
"step_name": StepName,
|
||||
"step_log_node": stepLogNode,
|
||||
"object_log_node": objectLogNode,
|
||||
"assigned_agents": {StepName: assigned_agents_in_step}, # 该步骤使用的 agent
|
||||
}
|
||||
|
||||
|
||||
@@ -394,8 +404,6 @@ def executePlan_streaming_dynamic(
|
||||
|
||||
execution_state_manager.start_execution(execution_id, general_goal)
|
||||
|
||||
print(colored(f"⏸️ 执行状态管理器已启动,支持暂停/恢复,execution_id={execution_id}", "green"))
|
||||
|
||||
# 准备执行
|
||||
KeyObjects = existingKeyObjects.copy() if existingKeyObjects else {}
|
||||
finishedStep_index = -1
|
||||
@@ -406,9 +414,6 @@ def executePlan_streaming_dynamic(
|
||||
if logNode["LogNodeType"] == "object":
|
||||
KeyObjects[logNode["NodeId"]] = logNode["content"]
|
||||
|
||||
if existingKeyObjects:
|
||||
print(colored(f"📦 使用已存在的 KeyObjects: {list(existingKeyObjects.keys())}", "cyan"))
|
||||
|
||||
# 确定要运行的步骤范围
|
||||
if num_StepToRun is None:
|
||||
run_to = len(plan["Collaboration Process"])
|
||||
@@ -421,9 +426,6 @@ def executePlan_streaming_dynamic(
|
||||
if execution_id:
|
||||
# 初始化执行管理器,使用传入的execution_id
|
||||
actual_execution_id = dynamic_execution_manager.start_execution(general_goal, steps_to_run, execution_id)
|
||||
print(colored(f"🚀 开始执行计划(动态模式),共 {len(steps_to_run)} 个步骤,执行ID: {actual_execution_id}", "cyan"))
|
||||
else:
|
||||
print(colored(f"🚀 开始执行计划(流式推送),共 {len(steps_to_run)} 个步骤", "cyan"))
|
||||
|
||||
total_steps = len(steps_to_run)
|
||||
|
||||
@@ -443,7 +445,7 @@ def executePlan_streaming_dynamic(
|
||||
# 检查暂停状态
|
||||
should_continue = await execution_state_manager.async_check_pause(execution_id)
|
||||
if not should_continue:
|
||||
print(colored("🛑 用户请求停止执行", "red"))
|
||||
util.print_colored("🛑 用户请求停止执行", "red")
|
||||
await queue.put({
|
||||
"type": "error",
|
||||
"message": "执行已被用户停止"
|
||||
@@ -466,29 +468,24 @@ def executePlan_streaming_dynamic(
|
||||
|
||||
# 如果没有步骤在队列中(queue_total_steps为0),立即退出
|
||||
if queue_total_steps == 0:
|
||||
print(colored(f"⚠️ 没有步骤在队列中,退出执行", "yellow"))
|
||||
break
|
||||
|
||||
# 如果所有步骤都已完成,等待可能的新步骤
|
||||
if completed_steps >= queue_total_steps:
|
||||
if empty_wait_count >= max_empty_wait_cycles:
|
||||
# 等待超时,退出执行
|
||||
print(colored(f"✅ 所有步骤执行完成,等待超时", "green"))
|
||||
break
|
||||
else:
|
||||
# 等待新步骤追加
|
||||
print(colored(f"⏳ 等待新步骤追加... ({empty_wait_count}/{max_empty_wait_cycles})", "cyan"))
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
# 还有步骤未完成,继续尝试获取
|
||||
print(colored(f"⏳ 等待步骤就绪... ({completed_steps}/{queue_total_steps})", "cyan"))
|
||||
await asyncio.sleep(0.5)
|
||||
empty_wait_count = 0 # 重置等待计数
|
||||
continue
|
||||
else:
|
||||
# 执行信息不存在,退出
|
||||
print(colored(f"⚠️ 执行信息不存在,退出执行", "yellow"))
|
||||
break
|
||||
|
||||
# 重置等待计数
|
||||
@@ -506,7 +503,8 @@ def executePlan_streaming_dynamic(
|
||||
KeyObjects,
|
||||
step_index,
|
||||
current_total_steps, # 使用动态更新的总步骤数
|
||||
execution_id
|
||||
execution_id,
|
||||
RehearsalLog # 传递 RehearsalLog 用于追加日志
|
||||
):
|
||||
if execution_state_manager.is_stopped(execution_id):
|
||||
await queue.put({
|
||||
@@ -533,7 +531,7 @@ def executePlan_streaming_dynamic(
|
||||
for step_index, stepDescrip in enumerate(steps_to_run):
|
||||
should_continue = await execution_state_manager.async_check_pause(execution_id)
|
||||
if not should_continue:
|
||||
print(colored("🛑 用户请求停止执行", "red"))
|
||||
util.print_colored("🛑 用户请求停止执行", "red")
|
||||
await queue.put({
|
||||
"type": "error",
|
||||
"message": "执行已被用户停止"
|
||||
@@ -547,7 +545,8 @@ def executePlan_streaming_dynamic(
|
||||
KeyObjects,
|
||||
step_index,
|
||||
total_steps,
|
||||
execution_id
|
||||
execution_id,
|
||||
RehearsalLog # 传递 RehearsalLog 用于追加日志
|
||||
):
|
||||
if execution_state_manager.is_stopped(execution_id):
|
||||
await queue.put({
|
||||
|
||||
@@ -63,10 +63,7 @@ class DynamicExecutionManager:
|
||||
# 初始化待执行步骤索引
|
||||
self._pending_steps[execution_id] = list(range(len(initial_steps)))
|
||||
|
||||
print(f"🚀 启动执行: {execution_id}")
|
||||
print(f"📊 初始步骤数: {len(initial_steps)}")
|
||||
print(f"📋 待执行步骤索引: {self._pending_steps[execution_id]}")
|
||||
|
||||
print(f"[Execution] 启动执行: {execution_id}, 步骤数={len(initial_steps)}")
|
||||
return execution_id
|
||||
|
||||
def add_steps(self, execution_id: str, new_steps: List[Dict]) -> int:
|
||||
@@ -82,7 +79,6 @@ class DynamicExecutionManager:
|
||||
"""
|
||||
with self._lock:
|
||||
if execution_id not in self._step_queues:
|
||||
print(f"⚠️ 警告: 执行ID {execution_id} 不存在,无法追加步骤")
|
||||
return 0
|
||||
|
||||
current_count = len(self._step_queues[execution_id])
|
||||
@@ -95,14 +91,9 @@ class DynamicExecutionManager:
|
||||
self._pending_steps[execution_id].extend(new_indices)
|
||||
|
||||
# 更新总步骤数
|
||||
old_total = self._executions[execution_id]["total_steps"]
|
||||
self._executions[execution_id]["total_steps"] = len(self._step_queues[execution_id])
|
||||
new_total = self._executions[execution_id]["total_steps"]
|
||||
|
||||
print(f"➕ 追加了 {len(new_steps)} 个步骤到 {execution_id}")
|
||||
print(f"📊 步骤总数: {old_total} -> {new_total}")
|
||||
print(f"📋 待执行步骤索引: {self._pending_steps[execution_id]}")
|
||||
|
||||
print(f"[Execution] 追加步骤: +{len(new_steps)}, 总计={self._executions[execution_id]['total_steps']}")
|
||||
return len(new_steps)
|
||||
|
||||
def get_next_step(self, execution_id: str) -> Optional[Dict]:
|
||||
@@ -117,7 +108,6 @@ class DynamicExecutionManager:
|
||||
"""
|
||||
with self._lock:
|
||||
if execution_id not in self._pending_steps:
|
||||
print(f"⚠️ 警告: 执行ID {execution_id} 不存在")
|
||||
return None
|
||||
|
||||
# 获取第一个待执行步骤的索引
|
||||
@@ -128,7 +118,6 @@ class DynamicExecutionManager:
|
||||
|
||||
# 从队列中获取步骤
|
||||
if step_index >= len(self._step_queues[execution_id]):
|
||||
print(f"⚠️ 警告: 步骤索引 {step_index} 超出范围")
|
||||
return None
|
||||
|
||||
step = self._step_queues[execution_id][step_index]
|
||||
@@ -137,9 +126,7 @@ class DynamicExecutionManager:
|
||||
self._executed_steps[execution_id].add(step_index)
|
||||
|
||||
step_name = step.get("StepName", "未知")
|
||||
print(f"🎯 获取下一个步骤: {step_name} (索引: {step_index})")
|
||||
print(f"📋 剩余待执行步骤: {len(self._pending_steps[execution_id])}")
|
||||
|
||||
print(f"[Execution] 获取步骤: {step_name} (索引: {step_index})")
|
||||
return step
|
||||
|
||||
def mark_step_completed(self, execution_id: str):
|
||||
@@ -154,9 +141,7 @@ class DynamicExecutionManager:
|
||||
self._executions[execution_id]["completed_steps"] += 1
|
||||
completed = self._executions[execution_id]["completed_steps"]
|
||||
total = self._executions[execution_id]["total_steps"]
|
||||
print(f"📊 步骤完成进度: {completed}/{total}")
|
||||
else:
|
||||
print(f"⚠️ 警告: 执行ID {execution_id} 不存在")
|
||||
print(f"[Execution] 步骤完成: {completed}/{total}")
|
||||
|
||||
def get_execution_info(self, execution_id: str) -> Optional[Dict]:
|
||||
"""
|
||||
|
||||
@@ -129,7 +129,6 @@ class ExecutionStateManager:
|
||||
state['goal'] = goal
|
||||
state['should_pause'] = False
|
||||
state['should_stop'] = False
|
||||
print(f"🚀 [DEBUG] start_execution: execution_id={execution_id}, 状态设置为 RUNNING, goal={goal}")
|
||||
|
||||
def pause_execution(self, execution_id: str) -> bool:
|
||||
"""
|
||||
@@ -143,19 +142,13 @@ class ExecutionStateManager:
|
||||
"""
|
||||
state = self._get_state(execution_id)
|
||||
if state is None:
|
||||
# 打印当前所有活跃的 execution_id,帮助调试
|
||||
active_ids = list(self._states.keys())
|
||||
print(f"⚠️ [DEBUG] pause_execution: execution_id={execution_id} 不存在")
|
||||
print(f" 当前活跃的 execution_id 列表: {active_ids}")
|
||||
return False
|
||||
|
||||
with self._get_lock(execution_id):
|
||||
if state['status'] != ExecutionStatus.RUNNING:
|
||||
print(f"⚠️ [DEBUG] pause_execution: execution_id={execution_id}, 当前状态是 {state['status']},无法暂停")
|
||||
return False
|
||||
state['status'] = ExecutionStatus.PAUSED
|
||||
state['should_pause'] = True
|
||||
print(f"⏸️ [DEBUG] pause_execution: execution_id={execution_id}, 状态设置为PAUSED")
|
||||
return True
|
||||
|
||||
def resume_execution(self, execution_id: str) -> bool:
|
||||
@@ -170,16 +163,13 @@ class ExecutionStateManager:
|
||||
"""
|
||||
state = self._get_state(execution_id)
|
||||
if state is None:
|
||||
print(f"⚠️ [DEBUG] resume_execution: execution_id={execution_id} 不存在")
|
||||
return False
|
||||
|
||||
with self._get_lock(execution_id):
|
||||
if state['status'] != ExecutionStatus.PAUSED:
|
||||
print(f"⚠️ [DEBUG] resume_execution: 当前状态不是PAUSED,而是 {state['status']}")
|
||||
return False
|
||||
state['status'] = ExecutionStatus.RUNNING
|
||||
state['should_pause'] = False
|
||||
print(f"▶️ [DEBUG] resume_execution: execution_id={execution_id}, 状态设置为RUNNING, should_pause=False")
|
||||
return True
|
||||
|
||||
def stop_execution(self, execution_id: str) -> bool:
|
||||
@@ -194,17 +184,14 @@ class ExecutionStateManager:
|
||||
"""
|
||||
state = self._get_state(execution_id)
|
||||
if state is None:
|
||||
print(f"⚠️ [DEBUG] stop_execution: execution_id={execution_id} 不存在")
|
||||
return False
|
||||
|
||||
with self._get_lock(execution_id):
|
||||
if state['status'] in [ExecutionStatus.IDLE, ExecutionStatus.STOPPED]:
|
||||
print(f"⚠️ [DEBUG] stop_execution: 当前状态是 {state['status']}, 无法停止")
|
||||
return False
|
||||
state['status'] = ExecutionStatus.STOPPED
|
||||
state['should_stop'] = True
|
||||
state['should_pause'] = False
|
||||
print(f"🛑 [DEBUG] stop_execution: execution_id={execution_id}, 状态设置为STOPPED")
|
||||
return True
|
||||
|
||||
def reset(self, execution_id: str):
|
||||
@@ -215,12 +202,10 @@ class ExecutionStateManager:
|
||||
state['goal'] = None
|
||||
state['should_pause'] = False
|
||||
state['should_stop'] = False
|
||||
print(f"🔄 [DEBUG] reset: execution_id={execution_id}, 状态重置为IDLE")
|
||||
|
||||
def cleanup(self, execution_id: str):
|
||||
"""清理指定 execution_id 的所有状态"""
|
||||
self._cleanup_state(execution_id)
|
||||
print(f"🧹 [DEBUG] cleanup: execution_id={execution_id} 的状态已清理")
|
||||
|
||||
async def async_check_pause(self, execution_id: str):
|
||||
"""
|
||||
@@ -248,7 +233,6 @@ class ExecutionStateManager:
|
||||
|
||||
# 检查停止标志
|
||||
if should_stop:
|
||||
print("🛑 [DEBUG] async_check_pause: execution_id={}, 检测到停止信号".format(execution_id))
|
||||
return False
|
||||
|
||||
# 检查暂停状态
|
||||
@@ -262,7 +246,6 @@ class ExecutionStateManager:
|
||||
should_stop = state['should_stop']
|
||||
|
||||
if not should_pause:
|
||||
print("▶️ [DEBUG] async_check_pause: execution_id={}, 从暂停中恢复!".format(execution_id))
|
||||
continue
|
||||
if should_stop:
|
||||
return False
|
||||
|
||||
@@ -113,7 +113,6 @@ class GenerationStateManager:
|
||||
state['status'] = GenerationStatus.GENERATING
|
||||
state['goal'] = goal
|
||||
state['should_stop'] = False
|
||||
print(f"🚀 [GenerationState] start_generation: generation_id={generation_id}, 状态设置为 GENERATING")
|
||||
|
||||
def stop_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
@@ -127,26 +126,21 @@ class GenerationStateManager:
|
||||
"""
|
||||
state = self._get_state(generation_id)
|
||||
if state is None:
|
||||
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 不存在")
|
||||
return True # 不存在也算停止成功
|
||||
|
||||
with self._get_lock(generation_id):
|
||||
if state['status'] == GenerationStatus.STOPPED:
|
||||
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经是 STOPPED 状态")
|
||||
return True # 已经停止也算成功
|
||||
|
||||
if state['status'] == GenerationStatus.COMPLETED:
|
||||
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经 COMPLETED,视为停止成功")
|
||||
return True # 已完成也视为停止成功
|
||||
|
||||
if state['status'] == GenerationStatus.IDLE:
|
||||
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 是 IDLE 状态,无需停止")
|
||||
return True # 空闲状态也视为无需停止
|
||||
|
||||
# 真正需要停止的情况
|
||||
state['status'] = GenerationStatus.STOPPED
|
||||
state['should_stop'] = True
|
||||
print(f"🛑 [GenerationState] stop_generation: generation_id={generation_id}, 状态设置为STOPPED")
|
||||
return True
|
||||
|
||||
def complete_generation(self, generation_id: str):
|
||||
@@ -154,12 +148,10 @@ class GenerationStateManager:
|
||||
state = self._ensure_state(generation_id)
|
||||
with self._get_lock(generation_id):
|
||||
state['status'] = GenerationStatus.COMPLETED
|
||||
print(f"✅ [GenerationState] complete_generation: generation_id={generation_id}")
|
||||
|
||||
def cleanup(self, generation_id: str):
|
||||
"""清理指定 generation_id 的所有状态"""
|
||||
self._cleanup_state(generation_id)
|
||||
print(f"🧹 [GenerationState] cleanup: generation_id={generation_id} 的状态已清理")
|
||||
|
||||
def should_stop(self, generation_id: str) -> bool:
|
||||
"""检查是否应该停止"""
|
||||
|
||||
@@ -94,6 +94,7 @@ def remove_render_spec(duty_spec):
|
||||
return duty_spec
|
||||
|
||||
|
||||
|
||||
def read_LLM_Completion(messages, useGroq=True):
|
||||
for _ in range(3):
|
||||
text = LLM_Completion(messages, useGroq=useGroq)
|
||||
|
||||
25
backend/db/__init__.py
Normal file
25
backend/db/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
AgentCoord 数据库模块
|
||||
提供 PostgreSQL 数据库连接、模型和 CRUD 操作
|
||||
基于 DATABASE_DESIGN.md 设计
|
||||
"""
|
||||
|
||||
from .database import get_db, get_db_context, test_connection, engine, text
|
||||
from .models import MultiAgentTask, UserAgent, TaskStatus
|
||||
from .crud import MultiAgentTaskCRUD, UserAgentCRUD
|
||||
|
||||
__all__ = [
|
||||
# 连接管理
|
||||
"get_db",
|
||||
"get_db_context",
|
||||
"test_connection",
|
||||
"engine",
|
||||
"text",
|
||||
# 模型
|
||||
"MultiAgentTask",
|
||||
"UserAgent",
|
||||
"TaskStatus",
|
||||
# CRUD
|
||||
"MultiAgentTaskCRUD",
|
||||
"UserAgentCRUD",
|
||||
]
|
||||
404
backend/db/crud.py
Normal file
404
backend/db/crud.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
数据库 CRUD 操作
|
||||
封装所有数据库操作方法 (基于 DATABASE_DESIGN.md)
|
||||
"""
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import MultiAgentTask, UserAgent
|
||||
|
||||
|
||||
class MultiAgentTaskCRUD:
|
||||
"""多智能体任务 CRUD 操作"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
db: Session,
|
||||
task_id: Optional[str] = None, # 可选,如果为 None 则自动生成
|
||||
user_id: str = "",
|
||||
query: str = "",
|
||||
agents_info: list = [],
|
||||
task_outline: Optional[dict] = None,
|
||||
assigned_agents: Optional[list] = None,
|
||||
agent_scores: Optional[dict] = None,
|
||||
result: Optional[str] = None,
|
||||
) -> MultiAgentTask:
|
||||
"""创建任务记录"""
|
||||
task = MultiAgentTask(
|
||||
task_id=task_id or str(uuid.uuid4()), # 如果没传则生成新的
|
||||
user_id=user_id,
|
||||
query=query,
|
||||
agents_info=agents_info,
|
||||
task_outline=task_outline,
|
||||
assigned_agents=assigned_agents,
|
||||
agent_scores=agent_scores,
|
||||
result=result,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
|
||||
"""根据任务 ID 获取记录"""
|
||||
return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_by_user_id(
|
||||
db: Session, user_id: str, limit: int = 50, offset: int = 0
|
||||
) -> List[MultiAgentTask]:
|
||||
"""根据用户 ID 获取任务记录"""
|
||||
return (
|
||||
db.query(MultiAgentTask)
|
||||
.filter(MultiAgentTask.user_id == user_id)
|
||||
.order_by(MultiAgentTask.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_recent(
|
||||
db: Session, limit: int = 20, offset: int = 0
|
||||
) -> List[MultiAgentTask]:
|
||||
"""获取最近的任务记录"""
|
||||
return (
|
||||
db.query(MultiAgentTask)
|
||||
.order_by(MultiAgentTask.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_result(
|
||||
db: Session, task_id: str, result: list
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新任务结果"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.result = result if result else []
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_task_outline(
|
||||
db: Session, task_id: str, task_outline: dict
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新任务大纲"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.task_outline = task_outline
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_assigned_agents(
|
||||
db: Session, task_id: str, assigned_agents: dict
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新分配的智能体(步骤名 -> agent列表)"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.assigned_agents = assigned_agents
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_agent_scores(
|
||||
db: Session, task_id: str, agent_scores: dict
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新智能体评分(合并模式,追加新步骤的评分)"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
# 合并现有评分数据和新评分数据
|
||||
existing_scores = task.agent_scores or {}
|
||||
merged_scores = {**existing_scores, **agent_scores} # 新数据覆盖/追加旧数据
|
||||
task.agent_scores = merged_scores
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_status(
|
||||
db: Session, task_id: str, status: str
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新任务状态"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.status = status
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
|
||||
"""增加任务执行次数"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.execution_count = (task.execution_count or 0) + 1
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_generation_id(
|
||||
db: Session, task_id: str, generation_id: str
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新生成 ID"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.generation_id = generation_id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_execution_id(
|
||||
db: Session, task_id: str, execution_id: str
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新执行 ID"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.execution_id = execution_id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_rehearsal_log(
|
||||
db: Session, task_id: str, rehearsal_log: list
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新排练日志"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
task.rehearsal_log = rehearsal_log if rehearsal_log else []
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def append_rehearsal_log(
|
||||
db: Session, task_id: str, log_entry: dict
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""追加排练日志条目"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
current_log = task.rehearsal_log or []
|
||||
if isinstance(current_log, list):
|
||||
current_log.append(log_entry)
|
||||
else:
|
||||
current_log = [log_entry]
|
||||
task.rehearsal_log = current_log
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def update_branches(
|
||||
db: Session, task_id: str, branches
|
||||
) -> Optional[MultiAgentTask]:
|
||||
"""更新任务分支数据
|
||||
支持两种格式:
|
||||
- list: 旧格式,直接覆盖
|
||||
- dict: 新格式 { flow_branches: [...], task_process_branches: {...} }
|
||||
两个 key 独立保存,互不干扰。
|
||||
"""
|
||||
import copy
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
if isinstance(branches, dict):
|
||||
# 新格式:字典,独立保存两个 key,互不干扰
|
||||
# 使用深拷贝避免引用共享问题
|
||||
existing = copy.deepcopy(task.branches) if task.branches else {}
|
||||
if isinstance(existing, dict):
|
||||
# 如果只更新 flow_branches,保留已有的 task_process_branches
|
||||
if 'flow_branches' in branches and 'task_process_branches' not in branches:
|
||||
branches['task_process_branches'] = existing.get('task_process_branches', {})
|
||||
# 如果只更新 task_process_branches,保留已有的 flow_branches
|
||||
if 'task_process_branches' in branches and 'flow_branches' not in branches:
|
||||
branches['flow_branches'] = existing.get('flow_branches', [])
|
||||
task.branches = branches
|
||||
else:
|
||||
# 旧格式:列表
|
||||
task.branches = branches if branches else []
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_branches(db: Session, task_id: str) -> Optional[list]:
|
||||
"""获取任务分支数据"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
return task.branches or []
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_by_status(
|
||||
db: Session, status: str, limit: int = 50, offset: int = 0
|
||||
) -> List[MultiAgentTask]:
|
||||
"""根据状态获取任务记录"""
|
||||
return (
|
||||
db.query(MultiAgentTask)
|
||||
.filter(MultiAgentTask.status == status)
|
||||
.order_by(MultiAgentTask.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_by_generation_id(
|
||||
db: Session, generation_id: str
|
||||
) -> List[MultiAgentTask]:
|
||||
"""根据生成 ID 获取任务记录"""
|
||||
return (
|
||||
db.query(MultiAgentTask)
|
||||
.filter(MultiAgentTask.generation_id == generation_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_by_execution_id(
|
||||
db: Session, execution_id: str
|
||||
) -> List[MultiAgentTask]:
|
||||
"""根据执行 ID 获取任务记录"""
|
||||
return (
|
||||
db.query(MultiAgentTask)
|
||||
.filter(MultiAgentTask.execution_id == execution_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, task_id: str) -> bool:
|
||||
"""删除任务记录"""
|
||||
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
|
||||
if task:
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class UserAgentCRUD:
|
||||
"""用户智能体配置 CRUD 操作"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
agent_name: str,
|
||||
agent_config: dict,
|
||||
) -> UserAgent:
|
||||
"""创建用户智能体配置"""
|
||||
agent = UserAgent(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
agent_name=agent_name,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
|
||||
"""根据 ID 获取配置"""
|
||||
return db.query(UserAgent).filter(UserAgent.id == agent_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_by_user_id(
|
||||
db: Session, user_id: str, limit: int = 50
|
||||
) -> List[UserAgent]:
|
||||
"""根据用户 ID 获取所有智能体配置"""
|
||||
return (
|
||||
db.query(UserAgent)
|
||||
.filter(UserAgent.user_id == user_id)
|
||||
.order_by(UserAgent.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(
|
||||
db: Session, user_id: str, agent_name: str
|
||||
) -> List[UserAgent]:
|
||||
"""根据用户 ID 和智能体名称获取配置"""
|
||||
return (
|
||||
db.query(UserAgent)
|
||||
.filter(
|
||||
UserAgent.user_id == user_id,
|
||||
UserAgent.agent_name == agent_name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_config(
|
||||
db: Session, agent_id: str, agent_config: dict
|
||||
) -> Optional[UserAgent]:
|
||||
"""更新智能体配置"""
|
||||
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
|
||||
if agent:
|
||||
agent.agent_config = agent_config
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, agent_id: str) -> bool:
|
||||
"""删除智能体配置"""
|
||||
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
|
||||
if agent:
|
||||
db.delete(agent)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def upsert(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
agent_name: str,
|
||||
agent_config: dict,
|
||||
) -> UserAgent:
|
||||
"""更新或插入用户智能体配置(根据 user_id + agent_name 判断唯一性)
|
||||
|
||||
如果已存在相同 user_id 和 agent_name 的记录,则更新配置;
|
||||
否则创建新记录。
|
||||
"""
|
||||
existing = (
|
||||
db.query(UserAgent)
|
||||
.filter(
|
||||
UserAgent.user_id == user_id,
|
||||
UserAgent.agent_name == agent_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.agent_config = agent_config
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
# 创建新记录
|
||||
agent = UserAgent(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
agent_name=agent_name,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return agent
|
||||
95
backend/db/database.py
Normal file
95
backend/db/database.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
数据库连接管理模块
|
||||
使用 SQLAlchemy ORM,支持同步操作
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from typing import Generator
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
|
||||
|
||||
# 读取配置
|
||||
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
|
||||
try:
|
||||
with open(yaml_file, "r", encoding="utf-8") as file:
|
||||
config = yaml.safe_load(file).get("database", {})
|
||||
except Exception:
|
||||
config = {}
|
||||
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""获取数据库连接 URL"""
|
||||
# 优先使用环境变量
|
||||
host = os.getenv("DB_HOST", config.get("host", "localhost"))
|
||||
port = os.getenv("DB_PORT", config.get("port", "5432"))
|
||||
user = os.getenv("DB_USER", config.get("username", "postgres"))
|
||||
password = os.getenv("DB_PASSWORD", config.get("password", ""))
|
||||
dbname = os.getenv("DB_NAME", config.get("name", "agentcoord"))
|
||||
|
||||
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
|
||||
|
||||
|
||||
# 创建引擎
|
||||
DATABASE_URL = get_database_url()
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=config.get("pool_size", 10),
|
||||
max_overflow=config.get("max_overflow", 20),
|
||||
pool_pre_ping=True,
|
||||
echo=False,
|
||||
# JSONB 类型处理器配置
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 基础类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
"""
|
||||
获取数据库会话
|
||||
用法: for db in get_db(): ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_context() -> Generator:
|
||||
"""
|
||||
上下文管理器方式获取数据库会话
|
||||
用法: with get_db_context() as db: ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_connection() -> bool:
|
||||
"""测试数据库连接"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
22
backend/db/init_db.py
Normal file
22
backend/db/init_db.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
数据库初始化脚本
|
||||
运行此脚本创建所有表结构
|
||||
基于 DATABASE_DESIGN.md 设计
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from db.database import engine, Base
|
||||
from db.models import MultiAgentTask, UserAgent
|
||||
|
||||
|
||||
def init_database():
|
||||
"""初始化数据库表结构"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_database()
|
||||
104
backend/db/models.py
Normal file
104
backend/db/models.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
SQLAlchemy ORM 数据模型
|
||||
对应数据库表结构 (基于 DATABASE_DESIGN.md)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum as PyEnum
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, Enum, Index, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .database import Base
|
||||
|
||||
|
||||
class TaskStatus(str, PyEnum):
|
||||
"""任务状态枚举"""
|
||||
GENERATING = "generating" # 生成中 - TaskProcess 生成阶段
|
||||
EXECUTING = "executing" # 执行中 - 任务执行阶段
|
||||
STOPPED = "stopped" # 已停止 - 用户手动停止执行
|
||||
COMPLETED = "completed" # 已完成 - 任务正常完成
|
||||
|
||||
|
||||
def utc_now():
|
||||
"""获取当前 UTC 时间"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class MultiAgentTask(Base):
|
||||
"""多智能体任务记录模型"""
|
||||
__tablename__ = "multi_agent_tasks"
|
||||
|
||||
task_id = Column(String(64), primary_key=True)
|
||||
user_id = Column(String(64), nullable=False, index=True)
|
||||
query = Column(Text, nullable=False)
|
||||
agents_info = Column(JSONB, nullable=False)
|
||||
task_outline = Column(JSONB)
|
||||
assigned_agents = Column(JSONB)
|
||||
agent_scores = Column(JSONB)
|
||||
result = Column(JSONB)
|
||||
status = Column(
|
||||
Enum(TaskStatus, name="task_status_enum", create_type=False),
|
||||
default=TaskStatus.GENERATING,
|
||||
nullable=False
|
||||
)
|
||||
execution_count = Column(Integer, default=0, nullable=False)
|
||||
generation_id = Column(String(64))
|
||||
execution_id = Column(String(64))
|
||||
rehearsal_log = Column(JSONB)
|
||||
branches = Column(JSONB) # 任务大纲探索分支数据
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_multi_agent_tasks_status", "status"),
|
||||
Index("idx_multi_agent_tasks_generation_id", "generation_id"),
|
||||
Index("idx_multi_agent_tasks_execution_id", "execution_id"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"task_id": self.task_id,
|
||||
"user_id": self.user_id,
|
||||
"query": self.query,
|
||||
"agents_info": self.agents_info,
|
||||
"task_outline": self.task_outline,
|
||||
"assigned_agents": self.assigned_agents,
|
||||
"agent_scores": self.agent_scores,
|
||||
"result": self.result,
|
||||
"status": self.status.value if self.status else None,
|
||||
"execution_count": self.execution_count,
|
||||
"generation_id": self.generation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"rehearsal_log": self.rehearsal_log,
|
||||
"branches": self.branches,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class UserAgent(Base):
|
||||
"""用户保存的智能体配置模型 (可选表)"""
|
||||
__tablename__ = "user_agents"
|
||||
|
||||
id = Column(String(64), primary_key=True)
|
||||
user_id = Column(String(64), nullable=False, index=True)
|
||||
agent_name = Column(String(100), nullable=False)
|
||||
agent_config = Column(JSONB, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_user_agents_user_created", "user_id", "created_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"user_id": self.user_id,
|
||||
"agent_name": self.agent_name,
|
||||
"agent_config": self.agent_config,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
70
backend/db/schema.sql
Normal file
70
backend/db/schema.sql
Normal file
@@ -0,0 +1,70 @@
|
||||
-- AgentCoord 数据库表结构
|
||||
-- 基于 DATABASE_DESIGN.md 设计
|
||||
-- 执行方式: psql -U postgres -d agentcoord -f schema.sql
|
||||
|
||||
-- =============================================================================
|
||||
-- 表1: multi_agent_tasks (多智能体任务记录)
|
||||
-- 状态枚举: pending/planning/generating/executing/completed/failed
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS multi_agent_tasks (
|
||||
task_id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
query TEXT NOT NULL,
|
||||
agents_info JSONB NOT NULL,
|
||||
task_outline JSONB,
|
||||
assigned_agents JSONB,
|
||||
agent_scores JSONB,
|
||||
result JSONB,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
execution_count INTEGER DEFAULT 0,
|
||||
generation_id VARCHAR(64),
|
||||
execution_id VARCHAR(64),
|
||||
rehearsal_log JSONB,
|
||||
branches JSONB, -- 任务大纲探索分支数据
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_user_id ON multi_agent_tasks(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_created_at ON multi_agent_tasks(created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_status ON multi_agent_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_generation_id ON multi_agent_tasks(generation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_execution_id ON multi_agent_tasks(execution_id);
|
||||
|
||||
-- =============================================================================
|
||||
-- 表2: user_agents (用户保存的智能体配置) - 可选表
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS user_agents (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
agent_name VARCHAR(100) NOT NULL,
|
||||
agent_config JSONB NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_agents_user_id ON user_agents(user_id);
|
||||
|
||||
-- =============================================================================
|
||||
-- 更新时间触发器函数
|
||||
-- =============================================================================
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = CURRENT_TIMESTAMP;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ language 'plpgsql';
|
||||
|
||||
-- 为 multi_agent_tasks 表创建触发器
|
||||
CREATE TRIGGER update_multi_agent_tasks_updated_at
|
||||
BEFORE UPDATE ON multi_agent_tasks
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
RAISE NOTICE '✅ PostgreSQL 数据库表结构创建完成!';
|
||||
RAISE NOTICE '表: multi_agent_tasks (多智能体任务记录)';
|
||||
RAISE NOTICE '表: user_agents (用户智能体配置)';
|
||||
END $$;
|
||||
1386
backend/server.py
1386
backend/server.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user