feat:1.数据库存储功能添加(初版)2.后端REST API版本代码清理

This commit is contained in:
liailing1026
2026-02-25 10:55:51 +08:00
parent f736cd104a
commit 2140cfaf92
35 changed files with 3912 additions and 2981 deletions

View File

@@ -49,7 +49,6 @@ def LLM_Completion(
messages: list[dict], stream: bool = True, useGroq: bool = True,model_config: dict = None
) -> str:
if model_config:
print_colored(f"Using model config: {model_config}", "blue")
return _call_with_custom_config(messages,stream,model_config)
if not useGroq or not FAST_DESIGN_MODE:
force_gpt4 = True
@@ -125,7 +124,7 @@ def _call_with_custom_config(messages: list[dict], stream: bool, model_config: d
#print(colored(full_reply_content, "blue", "on_white"), end="")
return full_reply_content
except Exception as e:
print_colored(f"Custom API error for model {api_model} :{str(e)}","red")
print_colored(f"[API Error] Custom API error: {str(e)}", "red")
raise
@@ -166,11 +165,11 @@ async def _achat_completion_stream_custom(messages:list[dict], temp_async_client
except httpx.RemoteProtocolError as e:
if attempt < max_retries - 1:
wait_time = (attempt + 1) *2
print_colored(f"⚠️ Stream connection interrupted (attempt {attempt+1}/{max_retries}). Retrying in {wait_time}s...", text_color="yellow")
print_colored(f"[API Warn] Stream connection interrupted, retrying in {wait_time}s...", "yellow")
await asyncio.sleep(wait_time)
continue
except Exception as e:
print_colored(f"Custom API stream error for model {api_model} :{str(e)}","red")
print_colored(f"[API Error] Custom API stream error: {str(e)}", "red")
raise
@@ -181,7 +180,6 @@ async def _achat_completion_stream_groq(messages: list[dict]) -> str:
max_attempts = 5
for attempt in range(max_attempts):
print("Attempt to use Groq (Fase Design Mode):")
try:
response = await groq_client.chat.completions.create(
messages=messages,
@@ -363,7 +361,7 @@ async def _achat_completion_stream(messages: list[dict]) -> str:
return full_reply_content
except Exception as e:
print_colored(f"OpenAI API error in _achat_completion_stream: {str(e)}", "red")
print_colored(f"[API Error] OpenAI API stream error: {str(e)}", "red")
raise
@@ -383,7 +381,7 @@ def _chat_completion(messages: list[dict]) -> str:
return content
except Exception as e:
print_colored(f"OpenAI API error in _chat_completion: {str(e)}", "red")
print_colored(f"[API Error] OpenAI API error: {str(e)}", "red")
raise

View File

@@ -77,7 +77,6 @@ def generate_AbilityRequirement(General_Goal, Current_Task):
),
},
]
print(messages[1]["content"])
return read_LLM_Completion(messages)["AbilityRequirement"]
@@ -106,6 +105,7 @@ def generate_AgentSelection(General_Goal, Current_Task, Agent_Board):
while True:
candidate = read_LLM_Completion(messages)["AgentSelectionPlan"]
# 添加调试打印
if len(candidate) > MAX_TEAM_SIZE:
teamSize = random.randint(2, MAX_TEAM_SIZE)
candidate = candidate[0:teamSize]
@@ -113,7 +113,6 @@ def generate_AgentSelection(General_Goal, Current_Task, Agent_Board):
continue
AgentSelectionPlan = sorted(candidate)
AgentSelectionPlan_set = set(AgentSelectionPlan)
# Check if every item in AgentSelectionPlan is in agentboard
if AgentSelectionPlan_set.issubset(agentboard_set):
break # If all items are in agentboard, break the loop

View File

@@ -85,7 +85,6 @@ class BaseAction():
# Handle missing agent profiles gracefully
model_config = None
if agentName not in AgentProfile_Dict:
print_colored(text=f"Warning: Agent '{agentName}' not found in AgentProfile_Dict. Using default profile.", text_color="yellow")
agentProfile = f"AI Agent named {agentName}"
else:
# agentProfile = AgentProfile_Dict[agentName]

View File

@@ -83,7 +83,6 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
}
# start the group chat
util.print_colored(TaskDescription, text_color="green")
ActionHistory = []
action_count = 0
total_actions = len(TaskProcess)
@@ -93,9 +92,6 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
actionType = ActionInfo["ActionType"]
agentName = ActionInfo["AgentName"]
# 添加进度日志
util.print_colored(f"🔄 Executing action {action_count}/{total_actions}: {actionType} by {agentName}", text_color="yellow")
if actionType in Action.customAction_Dict:
currentAction = Action.customAction_Dict[actionType](
info=ActionInfo,
@@ -126,11 +122,4 @@ def executePlan(plan, num_StepToRun, RehearsalLog, AgentProfile_Dict):
stepLogNode["ActionHistory"] = ActionHistory
# Return Output
print(
colored(
"$Run " + str(StepRun_count) + "step$",
color="black",
on_color="on_white",
)
)
return RehearsalLog

View File

@@ -107,7 +107,6 @@ def get_parallel_batches(TaskProcess: List[Dict], dependency_map: Dict[int, List
# 避免死循环
remaining = [i for i in range(len(TaskProcess)) if i not in completed]
if remaining:
print(colored(f"警告: 检测到循环依赖,强制串行执行: {remaining}", "yellow"))
ready_to_run = remaining[:1]
else:
break
@@ -188,7 +187,8 @@ async def execute_step_async_streaming(
KeyObjects: Dict,
step_index: int,
total_steps: int,
execution_id: str = None
execution_id: str = None,
RehearsalLog: List = None # 用于追加日志到历史记录
) -> Generator[Dict, None, None]:
"""
异步执行单个步骤,支持流式返回
@@ -276,8 +276,9 @@ async def execute_step_async_streaming(
total_actions = len(TaskProcess)
completed_actions = 0
# 步骤开始日志
util.print_colored(
f"📋 步骤 {step_index + 1}/{total_steps}: {StepName} ({total_actions} 个动作, 分 {len(batches)}并行执行)",
f"📋 步骤 {step_index + 1}/{total_steps}: {StepName} ({total_actions} 个动作, 分 {len(batches)} 批执行)",
text_color="cyan"
)
@@ -286,11 +287,11 @@ async def execute_step_async_streaming(
# 在每个批次执行前检查暂停状态
should_continue = await execution_state_manager.async_check_pause(execution_id)
if not should_continue:
util.print_colored("🛑 用户请求停止执行", "red")
return
batch_size = len(batch_indices)
# 批次执行日志
if batch_size > 1:
util.print_colored(
f"🚦 批次 {batch_index + 1}/{len(batches)}: 并行执行 {batch_size} 个动作",
@@ -298,7 +299,7 @@ async def execute_step_async_streaming(
)
else:
util.print_colored(
f"🔄 动作 {completed_actions + 1}/{total_actions}: 串行执行",
f"🔄 批次 {batch_index + 1}/{len(batches)}: 串行执行",
text_color="yellow"
)
@@ -353,12 +354,21 @@ async def execute_step_async_streaming(
objectLogNode["content"] = KeyObjects[OutputName]
stepLogNode["ActionHistory"] = ActionHistory
# 收集该步骤使用的 agent去重
assigned_agents_in_step = list(set(Agent_List)) if Agent_List else []
# 追加到 RehearsalLog因为 RehearsalLog 是可变对象,会反映到原列表)
if RehearsalLog is not None:
RehearsalLog.append(stepLogNode)
RehearsalLog.append(objectLogNode)
yield {
"type": "step_complete",
"step_index": step_index,
"step_name": StepName,
"step_log_node": stepLogNode,
"object_log_node": objectLogNode,
"assigned_agents": {StepName: assigned_agents_in_step}, # 该步骤使用的 agent
}
@@ -394,8 +404,6 @@ def executePlan_streaming_dynamic(
execution_state_manager.start_execution(execution_id, general_goal)
print(colored(f"⏸️ 执行状态管理器已启动,支持暂停/恢复execution_id={execution_id}", "green"))
# 准备执行
KeyObjects = existingKeyObjects.copy() if existingKeyObjects else {}
finishedStep_index = -1
@@ -406,9 +414,6 @@ def executePlan_streaming_dynamic(
if logNode["LogNodeType"] == "object":
KeyObjects[logNode["NodeId"]] = logNode["content"]
if existingKeyObjects:
print(colored(f"📦 使用已存在的 KeyObjects: {list(existingKeyObjects.keys())}", "cyan"))
# 确定要运行的步骤范围
if num_StepToRun is None:
run_to = len(plan["Collaboration Process"])
@@ -421,9 +426,6 @@ def executePlan_streaming_dynamic(
if execution_id:
# 初始化执行管理器使用传入的execution_id
actual_execution_id = dynamic_execution_manager.start_execution(general_goal, steps_to_run, execution_id)
print(colored(f"🚀 开始执行计划(动态模式),共 {len(steps_to_run)} 个步骤执行ID: {actual_execution_id}", "cyan"))
else:
print(colored(f"🚀 开始执行计划(流式推送),共 {len(steps_to_run)} 个步骤", "cyan"))
total_steps = len(steps_to_run)
@@ -443,7 +445,7 @@ def executePlan_streaming_dynamic(
# 检查暂停状态
should_continue = await execution_state_manager.async_check_pause(execution_id)
if not should_continue:
print(colored("🛑 用户请求停止执行", "red"))
util.print_colored("🛑 用户请求停止执行", "red")
await queue.put({
"type": "error",
"message": "执行已被用户停止"
@@ -466,29 +468,24 @@ def executePlan_streaming_dynamic(
# 如果没有步骤在队列中queue_total_steps为0立即退出
if queue_total_steps == 0:
print(colored(f"⚠️ 没有步骤在队列中,退出执行", "yellow"))
break
# 如果所有步骤都已完成,等待可能的新步骤
if completed_steps >= queue_total_steps:
if empty_wait_count >= max_empty_wait_cycles:
# 等待超时,退出执行
print(colored(f"✅ 所有步骤执行完成,等待超时", "green"))
break
else:
# 等待新步骤追加
print(colored(f"⏳ 等待新步骤追加... ({empty_wait_count}/{max_empty_wait_cycles})", "cyan"))
await asyncio.sleep(1)
continue
else:
# 还有步骤未完成,继续尝试获取
print(colored(f"⏳ 等待步骤就绪... ({completed_steps}/{queue_total_steps})", "cyan"))
await asyncio.sleep(0.5)
empty_wait_count = 0 # 重置等待计数
continue
else:
# 执行信息不存在,退出
print(colored(f"⚠️ 执行信息不存在,退出执行", "yellow"))
break
# 重置等待计数
@@ -506,7 +503,8 @@ def executePlan_streaming_dynamic(
KeyObjects,
step_index,
current_total_steps, # 使用动态更新的总步骤数
execution_id
execution_id,
RehearsalLog # 传递 RehearsalLog 用于追加日志
):
if execution_state_manager.is_stopped(execution_id):
await queue.put({
@@ -533,7 +531,7 @@ def executePlan_streaming_dynamic(
for step_index, stepDescrip in enumerate(steps_to_run):
should_continue = await execution_state_manager.async_check_pause(execution_id)
if not should_continue:
print(colored("🛑 用户请求停止执行", "red"))
util.print_colored("🛑 用户请求停止执行", "red")
await queue.put({
"type": "error",
"message": "执行已被用户停止"
@@ -547,7 +545,8 @@ def executePlan_streaming_dynamic(
KeyObjects,
step_index,
total_steps,
execution_id
execution_id,
RehearsalLog # 传递 RehearsalLog 用于追加日志
):
if execution_state_manager.is_stopped(execution_id):
await queue.put({

View File

@@ -63,10 +63,7 @@ class DynamicExecutionManager:
# 初始化待执行步骤索引
self._pending_steps[execution_id] = list(range(len(initial_steps)))
print(f"🚀 启动执行: {execution_id}")
print(f"📊 初始步骤数: {len(initial_steps)}")
print(f"📋 待执行步骤索引: {self._pending_steps[execution_id]}")
print(f"[Execution] 启动执行: {execution_id}, 步骤数={len(initial_steps)}")
return execution_id
def add_steps(self, execution_id: str, new_steps: List[Dict]) -> int:
@@ -82,7 +79,6 @@ class DynamicExecutionManager:
"""
with self._lock:
if execution_id not in self._step_queues:
print(f"⚠️ 警告: 执行ID {execution_id} 不存在,无法追加步骤")
return 0
current_count = len(self._step_queues[execution_id])
@@ -95,14 +91,9 @@ class DynamicExecutionManager:
self._pending_steps[execution_id].extend(new_indices)
# 更新总步骤数
old_total = self._executions[execution_id]["total_steps"]
self._executions[execution_id]["total_steps"] = len(self._step_queues[execution_id])
new_total = self._executions[execution_id]["total_steps"]
print(f" 追加了 {len(new_steps)} 个步骤到 {execution_id}")
print(f"📊 步骤总数: {old_total} -> {new_total}")
print(f"📋 待执行步骤索引: {self._pending_steps[execution_id]}")
print(f"[Execution] 追加步骤: +{len(new_steps)}, 总计={self._executions[execution_id]['total_steps']}")
return len(new_steps)
def get_next_step(self, execution_id: str) -> Optional[Dict]:
@@ -117,7 +108,6 @@ class DynamicExecutionManager:
"""
with self._lock:
if execution_id not in self._pending_steps:
print(f"⚠️ 警告: 执行ID {execution_id} 不存在")
return None
# 获取第一个待执行步骤的索引
@@ -128,7 +118,6 @@ class DynamicExecutionManager:
# 从队列中获取步骤
if step_index >= len(self._step_queues[execution_id]):
print(f"⚠️ 警告: 步骤索引 {step_index} 超出范围")
return None
step = self._step_queues[execution_id][step_index]
@@ -137,9 +126,7 @@ class DynamicExecutionManager:
self._executed_steps[execution_id].add(step_index)
step_name = step.get("StepName", "未知")
print(f"🎯 获取下一个步骤: {step_name} (索引: {step_index})")
print(f"📋 剩余待执行步骤: {len(self._pending_steps[execution_id])}")
print(f"[Execution] 获取步骤: {step_name} (索引: {step_index})")
return step
def mark_step_completed(self, execution_id: str):
@@ -154,9 +141,7 @@ class DynamicExecutionManager:
self._executions[execution_id]["completed_steps"] += 1
completed = self._executions[execution_id]["completed_steps"]
total = self._executions[execution_id]["total_steps"]
print(f"📊 步骤完成进度: {completed}/{total}")
else:
print(f"⚠️ 警告: 执行ID {execution_id} 不存在")
print(f"[Execution] 步骤完成: {completed}/{total}")
def get_execution_info(self, execution_id: str) -> Optional[Dict]:
"""

View File

@@ -129,7 +129,6 @@ class ExecutionStateManager:
state['goal'] = goal
state['should_pause'] = False
state['should_stop'] = False
print(f"🚀 [DEBUG] start_execution: execution_id={execution_id}, 状态设置为 RUNNING, goal={goal}")
def pause_execution(self, execution_id: str) -> bool:
"""
@@ -143,19 +142,13 @@ class ExecutionStateManager:
"""
state = self._get_state(execution_id)
if state is None:
# 打印当前所有活跃的 execution_id帮助调试
active_ids = list(self._states.keys())
print(f"⚠️ [DEBUG] pause_execution: execution_id={execution_id} 不存在")
print(f" 当前活跃的 execution_id 列表: {active_ids}")
return False
with self._get_lock(execution_id):
if state['status'] != ExecutionStatus.RUNNING:
print(f"⚠️ [DEBUG] pause_execution: execution_id={execution_id}, 当前状态是 {state['status']},无法暂停")
return False
state['status'] = ExecutionStatus.PAUSED
state['should_pause'] = True
print(f"⏸️ [DEBUG] pause_execution: execution_id={execution_id}, 状态设置为PAUSED")
return True
def resume_execution(self, execution_id: str) -> bool:
@@ -170,16 +163,13 @@ class ExecutionStateManager:
"""
state = self._get_state(execution_id)
if state is None:
print(f"⚠️ [DEBUG] resume_execution: execution_id={execution_id} 不存在")
return False
with self._get_lock(execution_id):
if state['status'] != ExecutionStatus.PAUSED:
print(f"⚠️ [DEBUG] resume_execution: 当前状态不是PAUSED而是 {state['status']}")
return False
state['status'] = ExecutionStatus.RUNNING
state['should_pause'] = False
print(f"▶️ [DEBUG] resume_execution: execution_id={execution_id}, 状态设置为RUNNING, should_pause=False")
return True
def stop_execution(self, execution_id: str) -> bool:
@@ -194,17 +184,14 @@ class ExecutionStateManager:
"""
state = self._get_state(execution_id)
if state is None:
print(f"⚠️ [DEBUG] stop_execution: execution_id={execution_id} 不存在")
return False
with self._get_lock(execution_id):
if state['status'] in [ExecutionStatus.IDLE, ExecutionStatus.STOPPED]:
print(f"⚠️ [DEBUG] stop_execution: 当前状态是 {state['status']}, 无法停止")
return False
state['status'] = ExecutionStatus.STOPPED
state['should_stop'] = True
state['should_pause'] = False
print(f"🛑 [DEBUG] stop_execution: execution_id={execution_id}, 状态设置为STOPPED")
return True
def reset(self, execution_id: str):
@@ -215,12 +202,10 @@ class ExecutionStateManager:
state['goal'] = None
state['should_pause'] = False
state['should_stop'] = False
print(f"🔄 [DEBUG] reset: execution_id={execution_id}, 状态重置为IDLE")
def cleanup(self, execution_id: str):
"""清理指定 execution_id 的所有状态"""
self._cleanup_state(execution_id)
print(f"🧹 [DEBUG] cleanup: execution_id={execution_id} 的状态已清理")
async def async_check_pause(self, execution_id: str):
"""
@@ -248,7 +233,6 @@ class ExecutionStateManager:
# 检查停止标志
if should_stop:
print("🛑 [DEBUG] async_check_pause: execution_id={}, 检测到停止信号".format(execution_id))
return False
# 检查暂停状态
@@ -262,7 +246,6 @@ class ExecutionStateManager:
should_stop = state['should_stop']
if not should_pause:
print("▶️ [DEBUG] async_check_pause: execution_id={}, 从暂停中恢复!".format(execution_id))
continue
if should_stop:
return False

View File

@@ -113,7 +113,6 @@ class GenerationStateManager:
state['status'] = GenerationStatus.GENERATING
state['goal'] = goal
state['should_stop'] = False
print(f"🚀 [GenerationState] start_generation: generation_id={generation_id}, 状态设置为 GENERATING")
def stop_generation(self, generation_id: str) -> bool:
"""
@@ -127,26 +126,21 @@ class GenerationStateManager:
"""
state = self._get_state(generation_id)
if state is None:
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 不存在")
return True # 不存在也算停止成功
with self._get_lock(generation_id):
if state['status'] == GenerationStatus.STOPPED:
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经是 STOPPED 状态")
return True # 已经停止也算成功
if state['status'] == GenerationStatus.COMPLETED:
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经 COMPLETED视为停止成功")
return True # 已完成也视为停止成功
if state['status'] == GenerationStatus.IDLE:
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 是 IDLE 状态,无需停止")
return True # 空闲状态也视为无需停止
# 真正需要停止的情况
state['status'] = GenerationStatus.STOPPED
state['should_stop'] = True
print(f"🛑 [GenerationState] stop_generation: generation_id={generation_id}, 状态设置为STOPPED")
return True
def complete_generation(self, generation_id: str):
@@ -154,12 +148,10 @@ class GenerationStateManager:
state = self._ensure_state(generation_id)
with self._get_lock(generation_id):
state['status'] = GenerationStatus.COMPLETED
print(f"✅ [GenerationState] complete_generation: generation_id={generation_id}")
def cleanup(self, generation_id: str):
"""清理指定 generation_id 的所有状态"""
self._cleanup_state(generation_id)
print(f"🧹 [GenerationState] cleanup: generation_id={generation_id} 的状态已清理")
def should_stop(self, generation_id: str) -> bool:
"""检查是否应该停止"""

View File

@@ -94,6 +94,7 @@ def remove_render_spec(duty_spec):
return duty_spec
def read_LLM_Completion(messages, useGroq=True):
for _ in range(3):
text = LLM_Completion(messages, useGroq=useGroq)

25
backend/db/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
"""
AgentCoord 数据库模块
提供 PostgreSQL 数据库连接、模型和 CRUD 操作
基于 DATABASE_DESIGN.md 设计
"""
from .database import get_db, get_db_context, test_connection, engine, text
from .models import MultiAgentTask, UserAgent, TaskStatus
from .crud import MultiAgentTaskCRUD, UserAgentCRUD
__all__ = [
# 连接管理
"get_db",
"get_db_context",
"test_connection",
"engine",
"text",
# 模型
"MultiAgentTask",
"UserAgent",
"TaskStatus",
# CRUD
"MultiAgentTaskCRUD",
"UserAgentCRUD",
]

404
backend/db/crud.py Normal file
View File

@@ -0,0 +1,404 @@
"""
数据库 CRUD 操作
封装所有数据库操作方法 (基于 DATABASE_DESIGN.md)
"""
import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from .models import MultiAgentTask, UserAgent
class MultiAgentTaskCRUD:
"""多智能体任务 CRUD 操作"""
@staticmethod
def create(
db: Session,
task_id: Optional[str] = None, # 可选,如果为 None 则自动生成
user_id: str = "",
query: str = "",
agents_info: list = [],
task_outline: Optional[dict] = None,
assigned_agents: Optional[list] = None,
agent_scores: Optional[dict] = None,
result: Optional[str] = None,
) -> MultiAgentTask:
"""创建任务记录"""
task = MultiAgentTask(
task_id=task_id or str(uuid.uuid4()), # 如果没传则生成新的
user_id=user_id,
query=query,
agents_info=agents_info,
task_outline=task_outline,
assigned_agents=assigned_agents,
agent_scores=agent_scores,
result=result,
)
db.add(task)
db.commit()
db.refresh(task)
return task
@staticmethod
def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
"""根据任务 ID 获取记录"""
return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
@staticmethod
def get_by_user_id(
db: Session, user_id: str, limit: int = 50, offset: int = 0
) -> List[MultiAgentTask]:
"""根据用户 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.user_id == user_id)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def get_recent(
db: Session, limit: int = 20, offset: int = 0
) -> List[MultiAgentTask]:
"""获取最近的任务记录"""
return (
db.query(MultiAgentTask)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def update_result(
db: Session, task_id: str, result: list
) -> Optional[MultiAgentTask]:
"""更新任务结果"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.result = result if result else []
db.commit()
db.refresh(task)
return task
@staticmethod
def update_task_outline(
db: Session, task_id: str, task_outline: dict
) -> Optional[MultiAgentTask]:
"""更新任务大纲"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.task_outline = task_outline
db.commit()
db.refresh(task)
return task
@staticmethod
def update_assigned_agents(
db: Session, task_id: str, assigned_agents: dict
) -> Optional[MultiAgentTask]:
"""更新分配的智能体(步骤名 -> agent列表"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.assigned_agents = assigned_agents
db.commit()
db.refresh(task)
return task
@staticmethod
def update_agent_scores(
db: Session, task_id: str, agent_scores: dict
) -> Optional[MultiAgentTask]:
"""更新智能体评分(合并模式,追加新步骤的评分)"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
# 合并现有评分数据和新评分数据
existing_scores = task.agent_scores or {}
merged_scores = {**existing_scores, **agent_scores} # 新数据覆盖/追加旧数据
task.agent_scores = merged_scores
db.commit()
db.refresh(task)
return task
@staticmethod
def update_status(
db: Session, task_id: str, status: str
) -> Optional[MultiAgentTask]:
"""更新任务状态"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.status = status
db.commit()
db.refresh(task)
return task
@staticmethod
def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
"""增加任务执行次数"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.execution_count = (task.execution_count or 0) + 1
db.commit()
db.refresh(task)
return task
@staticmethod
def update_generation_id(
db: Session, task_id: str, generation_id: str
) -> Optional[MultiAgentTask]:
"""更新生成 ID"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.generation_id = generation_id
db.commit()
db.refresh(task)
return task
@staticmethod
def update_execution_id(
db: Session, task_id: str, execution_id: str
) -> Optional[MultiAgentTask]:
"""更新执行 ID"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.execution_id = execution_id
db.commit()
db.refresh(task)
return task
@staticmethod
def update_rehearsal_log(
db: Session, task_id: str, rehearsal_log: list
) -> Optional[MultiAgentTask]:
"""更新排练日志"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.rehearsal_log = rehearsal_log if rehearsal_log else []
db.commit()
db.refresh(task)
return task
@staticmethod
def append_rehearsal_log(
db: Session, task_id: str, log_entry: dict
) -> Optional[MultiAgentTask]:
"""追加排练日志条目"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
current_log = task.rehearsal_log or []
if isinstance(current_log, list):
current_log.append(log_entry)
else:
current_log = [log_entry]
task.rehearsal_log = current_log
db.commit()
db.refresh(task)
return task
@staticmethod
def update_branches(
db: Session, task_id: str, branches
) -> Optional[MultiAgentTask]:
"""更新任务分支数据
支持两种格式:
- list: 旧格式,直接覆盖
- dict: 新格式 { flow_branches: [...], task_process_branches: {...} }
两个 key 独立保存,互不干扰。
"""
import copy
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
if isinstance(branches, dict):
# 新格式:字典,独立保存两个 key互不干扰
# 使用深拷贝避免引用共享问题
existing = copy.deepcopy(task.branches) if task.branches else {}
if isinstance(existing, dict):
# 如果只更新 flow_branches保留已有的 task_process_branches
if 'flow_branches' in branches and 'task_process_branches' not in branches:
branches['task_process_branches'] = existing.get('task_process_branches', {})
# 如果只更新 task_process_branches保留已有的 flow_branches
if 'task_process_branches' in branches and 'flow_branches' not in branches:
branches['flow_branches'] = existing.get('flow_branches', [])
task.branches = branches
else:
# 旧格式:列表
task.branches = branches if branches else []
db.commit()
db.refresh(task)
return task
@staticmethod
def get_branches(db: Session, task_id: str) -> Optional[list]:
"""获取任务分支数据"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
return task.branches or []
return []
@staticmethod
def get_by_status(
db: Session, status: str, limit: int = 50, offset: int = 0
) -> List[MultiAgentTask]:
"""根据状态获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.status == status)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def get_by_generation_id(
db: Session, generation_id: str
) -> List[MultiAgentTask]:
"""根据生成 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.generation_id == generation_id)
.all()
)
@staticmethod
def get_by_execution_id(
db: Session, execution_id: str
) -> List[MultiAgentTask]:
"""根据执行 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.execution_id == execution_id)
.all()
)
@staticmethod
def delete(db: Session, task_id: str) -> bool:
"""删除任务记录"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
db.delete(task)
db.commit()
return True
return False
class UserAgentCRUD:
"""用户智能体配置 CRUD 操作"""
@staticmethod
def create(
db: Session,
user_id: str,
agent_name: str,
agent_config: dict,
) -> UserAgent:
"""创建用户智能体配置"""
agent = UserAgent(
id=str(uuid.uuid4()),
user_id=user_id,
agent_name=agent_name,
agent_config=agent_config,
)
db.add(agent)
db.commit()
db.refresh(agent)
return agent
@staticmethod
def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
"""根据 ID 获取配置"""
return db.query(UserAgent).filter(UserAgent.id == agent_id).first()
@staticmethod
def get_by_user_id(
db: Session, user_id: str, limit: int = 50
) -> List[UserAgent]:
"""根据用户 ID 获取所有智能体配置"""
return (
db.query(UserAgent)
.filter(UserAgent.user_id == user_id)
.order_by(UserAgent.created_at.desc())
.limit(limit)
.all()
)
@staticmethod
def get_by_name(
db: Session, user_id: str, agent_name: str
) -> List[UserAgent]:
"""根据用户 ID 和智能体名称获取配置"""
return (
db.query(UserAgent)
.filter(
UserAgent.user_id == user_id,
UserAgent.agent_name == agent_name,
)
.all()
)
@staticmethod
def update_config(
db: Session, agent_id: str, agent_config: dict
) -> Optional[UserAgent]:
"""更新智能体配置"""
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
if agent:
agent.agent_config = agent_config
db.commit()
db.refresh(agent)
return agent
@staticmethod
def delete(db: Session, agent_id: str) -> bool:
"""删除智能体配置"""
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
if agent:
db.delete(agent)
db.commit()
return True
return False
@staticmethod
def upsert(
db: Session,
user_id: str,
agent_name: str,
agent_config: dict,
) -> UserAgent:
"""更新或插入用户智能体配置(根据 user_id + agent_name 判断唯一性)
如果已存在相同 user_id 和 agent_name 的记录,则更新配置;
否则创建新记录。
"""
existing = (
db.query(UserAgent)
.filter(
UserAgent.user_id == user_id,
UserAgent.agent_name == agent_name,
)
.first()
)
if existing:
# 更新现有记录
existing.agent_config = agent_config
db.commit()
db.refresh(existing)
return existing
else:
# 创建新记录
agent = UserAgent(
id=str(uuid.uuid4()),
user_id=user_id,
agent_name=agent_name,
agent_config=agent_config,
)
db.add(agent)
db.commit()
db.refresh(agent)
return agent

95
backend/db/database.py Normal file
View File

@@ -0,0 +1,95 @@
"""
数据库连接管理模块
使用 SQLAlchemy ORM支持同步操作
"""
import os
import yaml
from typing import Generator
from contextlib import contextmanager
import json
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import QueuePool
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
# 读取配置
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
with open(yaml_file, "r", encoding="utf-8") as file:
config = yaml.safe_load(file).get("database", {})
except Exception:
config = {}
def get_database_url() -> str:
"""获取数据库连接 URL"""
# 优先使用环境变量
host = os.getenv("DB_HOST", config.get("host", "localhost"))
port = os.getenv("DB_PORT", config.get("port", "5432"))
user = os.getenv("DB_USER", config.get("username", "postgres"))
password = os.getenv("DB_PASSWORD", config.get("password", ""))
dbname = os.getenv("DB_NAME", config.get("name", "agentcoord"))
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
# 创建引擎
DATABASE_URL = get_database_url()
engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=config.get("pool_size", 10),
max_overflow=config.get("max_overflow", 20),
pool_pre_ping=True,
echo=False,
# JSONB 类型处理器配置
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 基础类
Base = declarative_base()
def get_db() -> Generator:
"""
获取数据库会话
用法: for db in get_db(): ...
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def get_db_context() -> Generator:
"""
上下文管理器方式获取数据库会话
用法: with get_db_context() as db: ...
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception as e:
db.rollback()
raise
finally:
db.close()
def test_connection() -> bool:
"""测试数据库连接"""
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
except Exception as e:
return False

22
backend/db/init_db.py Normal file
View File

@@ -0,0 +1,22 @@
"""
数据库初始化脚本
运行此脚本创建所有表结构
基于 DATABASE_DESIGN.md 设计
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from db.database import engine, Base
from db.models import MultiAgentTask, UserAgent
def init_database():
"""初始化数据库表结构"""
Base.metadata.create_all(bind=engine)
if __name__ == "__main__":
init_database()

104
backend/db/models.py Normal file
View File

@@ -0,0 +1,104 @@
"""
SQLAlchemy ORM 数据模型
对应数据库表结构 (基于 DATABASE_DESIGN.md)
"""
import uuid
from datetime import datetime, timezone
from enum import Enum as PyEnum
from sqlalchemy import Column, String, Text, DateTime, Integer, Enum, Index, ForeignKey
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from .database import Base
class TaskStatus(str, PyEnum):
"""任务状态枚举"""
GENERATING = "generating" # 生成中 - TaskProcess 生成阶段
EXECUTING = "executing" # 执行中 - 任务执行阶段
STOPPED = "stopped" # 已停止 - 用户手动停止执行
COMPLETED = "completed" # 已完成 - 任务正常完成
def utc_now():
"""获取当前 UTC 时间"""
return datetime.now(timezone.utc)
class MultiAgentTask(Base):
"""多智能体任务记录模型"""
__tablename__ = "multi_agent_tasks"
task_id = Column(String(64), primary_key=True)
user_id = Column(String(64), nullable=False, index=True)
query = Column(Text, nullable=False)
agents_info = Column(JSONB, nullable=False)
task_outline = Column(JSONB)
assigned_agents = Column(JSONB)
agent_scores = Column(JSONB)
result = Column(JSONB)
status = Column(
Enum(TaskStatus, name="task_status_enum", create_type=False),
default=TaskStatus.GENERATING,
nullable=False
)
execution_count = Column(Integer, default=0, nullable=False)
generation_id = Column(String(64))
execution_id = Column(String(64))
rehearsal_log = Column(JSONB)
branches = Column(JSONB) # 任务大纲探索分支数据
created_at = Column(DateTime(timezone=True), default=utc_now)
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
__table_args__ = (
Index("idx_multi_agent_tasks_status", "status"),
Index("idx_multi_agent_tasks_generation_id", "generation_id"),
Index("idx_multi_agent_tasks_execution_id", "execution_id"),
)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"task_id": self.task_id,
"user_id": self.user_id,
"query": self.query,
"agents_info": self.agents_info,
"task_outline": self.task_outline,
"assigned_agents": self.assigned_agents,
"agent_scores": self.agent_scores,
"result": self.result,
"status": self.status.value if self.status else None,
"execution_count": self.execution_count,
"generation_id": self.generation_id,
"execution_id": self.execution_id,
"rehearsal_log": self.rehearsal_log,
"branches": self.branches,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class UserAgent(Base):
"""用户保存的智能体配置模型 (可选表)"""
__tablename__ = "user_agents"
id = Column(String(64), primary_key=True)
user_id = Column(String(64), nullable=False, index=True)
agent_name = Column(String(100), nullable=False)
agent_config = Column(JSONB, nullable=False)
created_at = Column(DateTime(timezone=True), default=utc_now)
__table_args__ = (
Index("idx_user_agents_user_created", "user_id", "created_at"),
)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"id": self.id,
"user_id": self.user_id,
"agent_name": self.agent_name,
"agent_config": self.agent_config,
"created_at": self.created_at.isoformat() if self.created_at else None,
}

70
backend/db/schema.sql Normal file
View File

@@ -0,0 +1,70 @@
-- AgentCoord 数据库表结构
-- 基于 DATABASE_DESIGN.md 设计
-- 执行方式: psql -U postgres -d agentcoord -f schema.sql
-- =============================================================================
-- 表1: multi_agent_tasks (多智能体任务记录)
-- 状态枚举: pending/planning/generating/executing/completed/failed
-- =============================================================================
CREATE TABLE IF NOT EXISTS multi_agent_tasks (
task_id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
query TEXT NOT NULL,
agents_info JSONB NOT NULL,
task_outline JSONB,
assigned_agents JSONB,
agent_scores JSONB,
result JSONB,
status VARCHAR(20) DEFAULT 'pending',
execution_count INTEGER DEFAULT 0,
generation_id VARCHAR(64),
execution_id VARCHAR(64),
rehearsal_log JSONB,
branches JSONB, -- 任务大纲探索分支数据
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_user_id ON multi_agent_tasks(user_id);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_created_at ON multi_agent_tasks(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_status ON multi_agent_tasks(status);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_generation_id ON multi_agent_tasks(generation_id);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_execution_id ON multi_agent_tasks(execution_id);
-- =============================================================================
-- 表2: user_agents (用户保存的智能体配置) - 可选表
-- =============================================================================
CREATE TABLE IF NOT EXISTS user_agents (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
agent_name VARCHAR(100) NOT NULL,
agent_config JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_user_agents_user_id ON user_agents(user_id);
-- =============================================================================
-- 更新时间触发器函数
-- =============================================================================
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- 为 multi_agent_tasks 表创建触发器
CREATE TRIGGER update_multi_agent_tasks_updated_at
BEFORE UPDATE ON multi_agent_tasks
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
DO $$
BEGIN
RAISE NOTICE '✅ PostgreSQL 数据库表结构创建完成!';
RAISE NOTICE '表: multi_agent_tasks (多智能体任务记录)';
RAISE NOTICE '表: user_agents (用户智能体配置)';
END $$;

File diff suppressed because it is too large Load Diff