3555 lines
126 KiB
Python
3555 lines
126 KiB
Python
from flask import Flask, request, jsonify, Response, stream_with_context, send_file
|
||
from flask_cors import CORS
|
||
from flask_socketio import SocketIO, emit, join_room, leave_room
|
||
import json
|
||
from DataProcess import Add_Collaboration_Brief_FrontEnd
|
||
from AgentCoord.RehearsalEngine_V2.ExecutePlan import executePlan
|
||
from AgentCoord.RehearsalEngine_V2.ExecutePlan_Optimized import executePlan_streaming
|
||
from AgentCoord.PlanEngine.basePlan_Generator import generate_basePlan
|
||
from AgentCoord.PlanEngine.fill_stepTask import fill_stepTask
|
||
from AgentCoord.PlanEngine.fill_stepTask_TaskProcess import (
|
||
fill_stepTask_TaskProcess,
|
||
)
|
||
from AgentCoord.PlanEngine.branch_PlanOutline import branch_PlanOutline
|
||
from AgentCoord.PlanEngine.branch_TaskProcess import branch_TaskProcess
|
||
from AgentCoord.PlanEngine.AgentSelectModify import (
|
||
AgentSelectModify_init,
|
||
AgentSelectModify_addAspect,
|
||
)
|
||
import os
|
||
import yaml
|
||
import argparse
|
||
import uuid
|
||
import copy
|
||
import base64
|
||
from typing import List, Dict, Optional
|
||
from datetime import datetime, timezone, timedelta
|
||
|
||
# 数据库模块导入
|
||
from db import (
|
||
get_db_context,
|
||
MultiAgentTaskCRUD,
|
||
UserAgentCRUD,
|
||
ExportRecordCRUD,
|
||
PlanShareCRUD,
|
||
TaskStatus,
|
||
text,
|
||
)
|
||
|
||
# 导出模块导入
|
||
from AgentCoord.Export import ExportFactory
|
||
|
||
# initialize global variables
# Configuration is read from ./config/config.yaml relative to the working directory.
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
    with open(yaml_file, "r", encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file)
except Exception:
    # Missing or unreadable config is not fatal: fall back to defaults.
    yaml_data = {}
# yaml.safe_load returns None for an empty document (and may return a list or a
# scalar for an unexpected layout); normalise so the .get() below cannot fail.
if not isinstance(yaml_data, dict):
    yaml_data = {}

# The USE_CACHE environment variable, when set, overrides config.yaml.
USE_CACHE: bool
_use_cache_env = os.getenv("USE_CACHE")
if _use_cache_env is None:
    USE_CACHE = bool(yaml_data.get("USE_CACHE", False))
else:
    USE_CACHE = _use_cache_env.strip().lower() in ["true", "1", "yes"]

# Agent list loaded by init() from AgentRepo/agentBoard_v1.json (None until then).
AgentBoard = None
# Agent name -> profile text, built by init().
AgentProfile_Dict = {}
# In-memory response cache keyed by str(request tuple); values are the JSON-able
# results returned to the client. Only consulted when USE_CACHE is truthy.
Request_Cache: dict[str, object] = {}

app = Flask(__name__)
# Let deployments supply their own secret; the historical hard-coded value is
# kept only as a fallback so existing setups keep working.
app.config['SECRET_KEY'] = os.getenv("SECRET_KEY", 'agentcoord-secret-key')

CORS(app)  # enable CORS support
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading')

# Static file service configuration (for accessing exported files).
EXPORT_DIR = os.path.join(os.getcwd(), "uploads", "exports")
|
||
|
||
@app.route('/uploads/<path:filename>', methods=['GET'])
def serve_export_file(filename):
    """Serve a file from the ``uploads`` directory (static access to exports).

    The directory is resolved from the current working directory, the same way
    ``EXPORT_DIR`` is built. Passing a bare relative ``'uploads'`` to Flask's
    ``send_from_directory`` would resolve it against the application's root
    path instead, which differs whenever the server is started from another
    directory. ``send_from_directory`` safely joins ``filename`` and responds
    with 404 for paths that escape the directory.
    """
    from flask import send_from_directory

    uploads_dir = os.path.join(os.getcwd(), "uploads")
    return send_from_directory(uploads_dir, filename)
|
||
|
||
def truncate_rehearsal_log(RehearsalLog: List, restart_from_step_index: int) -> List:
    """
    Cut a RehearsalLog down to the results produced before a restart point.

    Args:
        RehearsalLog: Original log; a sequence of dict nodes whose
            ``LogNodeType`` is ``"step"`` or ``"object"``.
        restart_from_step_index: Index of the first step that will be re-run
            (e.g. 1 keeps step 0 and re-executes from step 1).

    Returns:
        A new list holding the retained nodes in their original order. Step
        nodes with index < restart_from_step_index are kept; object nodes are
        kept while the number of steps seen so far is <= the restart index.
        Nodes of any other type are dropped.

    Example:
        restart_from_step_index = 1
        RehearsalLog = [step0, object0, step1, object1, step2, object2]
        result       = [step0, object0]   # only step 0's results remain
    """
    kept_nodes: List = []
    steps_passed = 0  # how many "step" nodes have been encountered so far

    for node in RehearsalLog:
        kind = node.get("LogNodeType")
        if kind == "object":
            # An object follows the step that produced it; keep it while that
            # step is still inside the retained range.
            if steps_passed <= restart_from_step_index:
                kept_nodes.append(node)
        elif kind == "step":
            if steps_passed < restart_from_step_index:
                kept_nodes.append(node)
            steps_passed += 1

    return kept_nodes
|
||
|
||
|
||
def convert_score_table_to_front_format(scoreTable: dict) -> dict:
    """
    Pivot a backend score table into the layout the front end expects.

    Backend layout:  {aspect: {agent: {Score, Reason}}}
    Front-end layout: {aspectList, agentScores: {agent: {aspect: {score, reason}}}}

    Both capitalised (``Score``/``Reason``) and lowercase (``score``/``reason``)
    keys are accepted, the capitalised one taking precedence; missing values
    default to 0 and "".

    Args:
        scoreTable: Backend score table.

    Returns:
        Dict with ``aspectList`` (aspects in table order) and ``agentScores``.
    """
    agent_scores: dict = {}
    for aspect_name, per_agent in scoreTable.items():
        for agent, entry in per_agent.items():
            agent_scores.setdefault(agent, {})[aspect_name] = {
                "score": entry.get("Score", entry.get("score", 0)),
                "reason": entry.get("Reason", entry.get("reason", "")),
            }

    return {"aspectList": list(scoreTable), "agentScores": agent_scores}
|
||
|
||
|
||
def init():
    """Load the request cache and the agent board from disk into module globals.

    Reads ``RequestCache/Request_Cache.json`` and ``AgentRepo/agentBoard_v1.json``
    relative to the current working directory, both as UTF-8. Either file may be
    missing or malformed; the corresponding globals then fall back to empty
    values (a warning is printed for anything other than a missing cache file)
    instead of raising.

    Side effects:
        Rebinds the globals ``Request_Cache`` (dict), ``AgentBoard`` (list of
        agent dicts) and ``AgentProfile_Dict`` (agent ``Name`` -> ``Profile``).
    """
    global AgentBoard, AgentProfile_Dict, Request_Cache

    # Load Request Cache. Read as UTF-8 explicitly: cached results may contain
    # non-ASCII text and the platform default encoding is not always UTF-8.
    cache_path = os.path.join(os.getcwd(), "RequestCache", "Request_Cache.json")
    try:
        with open(cache_path, "r", encoding="utf-8") as json_file:
            Request_Cache = json.load(json_file)
        if not isinstance(Request_Cache, dict):
            # Callers index and assign into the cache by key.
            print(f"[init] Ignoring request cache {cache_path}: not a JSON object")
            Request_Cache = {}
    except FileNotFoundError:
        # No cache on disk yet: start empty.
        Request_Cache = {}
    except Exception as e:
        print(f"[init] Failed to load request cache {cache_path}: {e}")
        Request_Cache = {}

    # Load Agent Board and build the name -> profile lookup.
    board_path = os.path.join(os.getcwd(), "AgentRepo", "agentBoard_v1.json")
    try:
        with open(board_path, "r", encoding="utf-8") as json_file:
            AgentBoard = json.load(json_file)
        AgentProfile_Dict = {item["Name"]: item["Profile"] for item in AgentBoard}
    except Exception as e:
        # Best effort: keep the server usable with an empty board.
        print(f"[init] Failed to load agent board {board_path}: {e}")
        AgentBoard = []
        AgentProfile_Dict = {}
|
||
|
||
|
||
# ==================== WebSocket 连接管理 ====================
|
||
@socketio.on('connect')
def handle_connect():
    """Acknowledge a newly connected Socket.IO client, echoing its session id."""
    payload = {'sid': request.sid, 'message': 'WebSocket连接成功'}
    emit('connected', payload)
|
||
|
||
|
||
@socketio.on('disconnect')
def handle_disconnect():
    """Socket.IO disconnect hook; no per-client cleanup is performed."""
    return None
|
||
|
||
|
||
@socketio.on('ping')
def handle_ping():
    """Heartbeat: answer a client 'ping' event with a 'pong' event."""
    reply_event = 'pong'
    emit(reply_event)
|
||
|
||
|
||
# ==================== WebSocket 事件处理 ====================
|
||
# 注:以下为WebSocket版本的接口,与REST API并存
|
||
# 逐步迁移核心接口到WebSocket
|
||
|
||
|
||
def _save_plan_execution_result(task_id, plan: dict, RehearsalLog: List, execution_id: str) -> None:
    """Persist the outcome of one execution run for ``task_id`` (best effort).

    Always stores the execution id and increments the execution count. When the
    log holds at least as many step nodes as the plan has steps, the
    RehearsalLog and the step results are saved and the task is marked
    COMPLETED; otherwise the task is marked STOPPED and no execution data is
    saved. Any error is printed with a traceback and not re-raised.
    """
    try:
        with get_db_context() as db:
            # NOTE(review): completion is judged from RehearsalLog, which assumes the
            # streaming executor appends its log nodes to this same list — confirm.
            step_results = [node for node in RehearsalLog if node.get("LogNodeType") == "step"]
            completed_steps_count = len(step_results)
            plan_steps_count = len(plan.get("Collaboration Process", [])) if plan else 0

            # Execution id and execution count are always recorded.
            MultiAgentTaskCRUD.update_execution_id(db, task_id, execution_id)
            MultiAgentTaskCRUD.increment_execution_count(db, task_id)

            is_complete_execution = completed_steps_count >= plan_steps_count and plan_steps_count > 0

            if is_complete_execution:
                # Full run: save all execution data (overwrite mode).
                MultiAgentTaskCRUD.update_rehearsal_log(db, task_id, RehearsalLog)
                print(f"[execute_plan_optimized] 完整执行完成,已保存 RehearsalLog({completed_steps_count} 个步骤),task_id={task_id}")

                MultiAgentTaskCRUD.update_result(db, task_id, step_results)
                print(f"[execute_plan_optimized] 完整执行完成,已保存 result({len(step_results)} 个步骤结果),task_id={task_id}")

                MultiAgentTaskCRUD.update_status(db, task_id, TaskStatus.COMPLETED)
            else:
                # Partial run (user stopped): keep previous data, only mark STOPPED.
                MultiAgentTaskCRUD.update_status(db, task_id, TaskStatus.STOPPED)
                print(f"[execute_plan_optimized] 用户停止执行,跳过保存执行数据,已完成 {completed_steps_count}/{plan_steps_count} 步骤,task_id={task_id}")

            # task_outline is deliberately NOT saved here: overwriting it would change
            # step ids and break the agent_scores mapping. assigned_agents is likewise
            # only written at generation time (user selection), never by execution.
    except Exception:
        import traceback
        traceback.print_exc()


@socketio.on('execute_plan_optimized')
def handle_execute_plan_optimized_ws(data):
    """
    WebSocket version: optimized streaming plan execution.

    Supports step-level streaming, action-level parallelism, dynamically
    appended steps and re-running from a given step. When ``task_id`` is given,
    the task status and results are persisted to the database.

    Request format:
    {
        "id": "request-id",
        "action": "execute_plan_optimized",
        "data": {
            "task_id": "task-id",      # optional: persist to the database
            "plan": {...},
            "num_StepToRun": null,
            "RehearsalLog": [],
            "enable_dynamic": true,
            "restart_from_step_index": 1,
            "execution_id": "..."      # optional: reuse a client-chosen id
        }
    }

    Emits ``progress`` events: ``execution_started`` (dynamic mode only), one
    ``streaming`` event per chunk, then ``complete``; ``error`` on failure.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        plan = incoming_data.get("plan")
        num_StepToRun = incoming_data.get("num_StepToRun")
        RehearsalLog = incoming_data.get("RehearsalLog", [])
        enable_dynamic = incoming_data.get("enable_dynamic", False)
        restart_from_step_index = incoming_data.get("restart_from_step_index")
        task_id = incoming_data.get("task_id")

        # Validate before touching the database so a bad request cannot leave
        # the task stuck in EXECUTING.
        if not isinstance(plan, dict):
            raise ValueError("缺少plan参数")

        # task_id is optional: only update the database when one was supplied.
        # NOTE(review): if execution later fails the status stays EXECUTING.
        if task_id:
            with get_db_context() as db:
                MultiAgentTaskCRUD.update_status(db, task_id, TaskStatus.EXECUTING)

        print(f"[WS] 开始执行计划: goal={plan.get('General Goal', '')}, dynamic={enable_dynamic}")

        # Agents used by each step, gathered from step_complete events. Currently
        # collected but not persisted (see _save_plan_execution_result).
        assigned_agents_collection = {}

        def collect_assigned_agents_from_chunk(chunk: str):
            """Record ``assigned_agents`` from a ``step_complete`` chunk (best effort)."""
            try:
                event_str = chunk.replace('data: ', '').replace('\n\n', '').strip()
                if not event_str:
                    return
                event = json.loads(event_str)
                if event.get('type') == 'step_complete':
                    step_assigned_agents = event.get('assigned_agents', {})
                    if step_assigned_agents:
                        assigned_agents_collection.update(step_assigned_agents)
            except Exception:
                # Unparseable chunks are ignored on purpose: collection must never
                # interrupt the stream.
                pass

        # Re-running from a given step: drop log entries from that step onward.
        if restart_from_step_index is not None:
            RehearsalLog = truncate_rehearsal_log(RehearsalLog, restart_from_step_index)

        # Use the client-supplied execution_id if present, otherwise derive one.
        execution_id = incoming_data.get("execution_id")
        if not execution_id:
            import time
            goal = plan.get('General Goal') or ''
            execution_id = f"{goal.replace(' ', '_')}_{int(time.time() * 1000)}"

        if enable_dynamic:
            # Dynamic mode: steps may be appended while running.
            from AgentCoord.RehearsalEngine_V2.ExecutePlan_Optimized import executePlan_streaming_dynamic

            # Confirm the execution id the client must use to append steps.
            emit('progress', {
                'id': request_id,
                'status': 'execution_started',
                'execution_id': execution_id,
                'message': '执行已启动,支持动态追加步骤'
            })
            stream_plan = executePlan_streaming_dynamic
        else:
            stream_plan = executePlan_streaming

        for chunk in stream_plan(
            plan=plan,
            num_StepToRun=num_StepToRun,
            RehearsalLog=RehearsalLog,
            AgentProfile_Dict=AgentProfile_Dict,
            execution_id=execution_id
        ):
            collect_assigned_agents_from_chunk(chunk)
            # Forward the chunk without its SSE framing.
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'data': chunk.replace('data: ', '').replace('\n\n', '')
            })

        emit('progress', {
            'id': request_id,
            'status': 'complete',
            'data': None
        })

        print(f"[WS] 执行计划完成: execution_id={execution_id}")

        if task_id:
            _save_plan_execution_result(task_id, plan, RehearsalLog, execution_id)

    except Exception as e:
        emit('progress', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('add_steps_to_execution')
def handle_add_steps_to_execution(data):
    """
    WebSocket version: append new steps to a running (dynamic) execution.

    Request format:
    {
        "id": "request-id",
        "action": "add_steps_to_execution",
        "data": {
            "execution_id": "execution_id",
            "new_steps": [...]
        }
    }

    Replies with a single ``response`` event: ``success`` with the number of
    steps added, or ``error`` when the id is missing, unknown/finished, or an
    exception occurs.
    """
    request_id = data.get('id')
    payload = data.get('data', {})

    def reply(**fields):
        # Every reply carries the request id first, then status and details.
        emit('response', {'id': request_id, **fields})

    try:
        from AgentCoord.RehearsalEngine_V2.dynamic_execution_manager import dynamic_execution_manager

        execution_id = payload.get('execution_id')
        new_steps = payload.get('new_steps', [])

        if not execution_id:
            reply(status='error', error='缺少execution_id参数')
            return

        # Queue the new steps on the running execution.
        added_count = dynamic_execution_manager.add_steps(execution_id, new_steps)

        if added_count <= 0:
            reply(status='error', error='执行ID不存在或已结束')
            return

        print(f"[WS] 成功追加 {added_count} 个步骤: execution_id={execution_id}")
        reply(
            status='success',
            data={
                'message': f'成功追加 {added_count} 个步骤',
                'added_count': added_count
            }
        )

    except Exception as e:
        reply(status='error', error=str(e))
|
||
|
||
|
||
def _build_base_plan_with_progress(request_id, incoming_data: dict) -> dict:
    """Generate a base plan for the request while streaming progress events.

    Emits the ``generating_outline``, ``building_plan``, ``adding_step`` (one per
    step) and ``rendering`` progress stages, and returns the plan after
    ``Add_Collaboration_Brief_FrontEnd`` has been applied.
    """
    # Stage 1: generate the plan outline.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'generating_outline',
        'message': '📋 正在生成计划大纲...'
    })

    from AgentCoord.PlanEngine.planOutline_Generator import generate_PlanOutline
    PlanOutline = generate_PlanOutline(
        InitialObject_List=incoming_data.get("Initial Input Object"),
        General_Goal=incoming_data.get("General Goal")
    )
    total_steps = len(PlanOutline)

    # Stage 2: build the base plan, adding steps one at a time.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'building_plan',
        'total_steps': total_steps,
        'message': f'🔨 正在构建计划,共 {total_steps} 个步骤...'
    })

    basePlan = {
        "General Goal": incoming_data.get("General Goal"),
        "Initial Input Object": incoming_data.get("Initial Input Object"),
        "Collaboration Process": []
    }

    for idx, stepItem in enumerate(PlanOutline, 1):
        # Empty placeholders for agent selection, task process and front-end brief.
        stepItem["AgentSelection"] = []
        stepItem["TaskProcess"] = []
        stepItem["Collaboration_Brief_frontEnd"] = {
            "template": "",
            "data": {}
        }
        basePlan["Collaboration Process"].append(stepItem)

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'adding_step',
            'step': idx,
            'total': total_steps,
            'step_name': stepItem.get("StepName", ""),
            'message': f'✅ 已添加步骤 {idx}/{total_steps}: {stepItem.get("StepName", "")}'
        })

    # Stage 3: attach the front-end rendering spec.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'rendering',
        'message': '🎨 正在处理渲染规范...'
    })

    return Add_Collaboration_Brief_FrontEnd(basePlan)


def _save_generated_base_plan(incoming_data: dict, basePlan_withRenderSpec: dict):
    """Create or update the task row for a generated base plan.

    If a task with the request's ``task_id`` exists its outline is replaced;
    otherwise a new task is created (``MultiAgentTaskCRUD.create`` generates an
    id when ``task_id`` is None). Either way a fresh generation id is stored and
    the status is set to GENERATING.

    Returns:
        ``(task_id, generation_id)`` for the task that was written.
    """
    user_id = incoming_data.get("user_id")
    task_id = incoming_data.get("task_id")
    generation_id = str(uuid.uuid4())

    with get_db_context() as db:
        existing_task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if existing_task:
            MultiAgentTaskCRUD.update_task_outline(db, task_id, basePlan_withRenderSpec)
        else:
            created_task = MultiAgentTaskCRUD.create(
                db=db,
                task_id=task_id,  # None -> a new id is generated
                user_id=user_id,
                query=incoming_data.get("General Goal", ""),
                agents_info=AgentBoard or [],
                task_outline=basePlan_withRenderSpec,
                assigned_agents=None,
                agent_scores=None,
                result=None,
            )
            # Use the id actually assigned (may be newly generated).
            task_id = created_task.task_id

        MultiAgentTaskCRUD.update_generation_id(db, task_id, generation_id)
        MultiAgentTaskCRUD.update_status(db, task_id, TaskStatus.GENERATING)

    return task_id, generation_id


@socketio.on('generate_base_plan')
def handle_generate_base_plan_ws(data):
    """
    WebSocket version: generate a base plan with streamed progress.

    Request format:
    {
        "id": "request-id",
        "action": "generate_base_plan",
        "data": {
            "General Goal": "...",
            "Initial Input Object": [...],
            "user_id": "...",      # optional
            "task_id": "..."       # optional: update this task instead of creating one
        }
    }

    Streaming events (progress, status "streaming"), in order:
    - stage "generating_outline"
    - stage "building_plan" (total_steps)
    - stage "adding_step" per step (step, total, step_name)
    - stage "rendering"
    - stage "complete"
    Final event:
    - response: {"id": request_id, "status": "success",
                 "data": {"task_id", "generation_id", "basePlan"}}

    On a cache hit generation is skipped, but the plan is still persisted and
    the response has the same shape as a fresh generation.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        requestIdentifier = str((
            "/generate_basePlan",
            incoming_data.get("General Goal"),
            incoming_data.get("Initial Input Object"),
        ))

        if USE_CACHE and requestIdentifier in Request_Cache:
            basePlan_withRenderSpec = Request_Cache[requestIdentifier]
        else:
            basePlan_withRenderSpec = _build_base_plan_with_progress(request_id, incoming_data)
            if USE_CACHE:
                Request_Cache[requestIdentifier] = basePlan_withRenderSpec

        task_id, generation_id = _save_generated_base_plan(incoming_data, basePlan_withRenderSpec)

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'complete',
            'message': '✅ 计划生成完成'
        })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': {
                "task_id": task_id,
                "generation_id": generation_id,
                "basePlan": basePlan_withRenderSpec
            }
        })

    except ValueError as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': f"An unexpected error occurred: {str(e)}"
        })
|
||
|
||
|
||
def _agent_board_for_llm(agent_board) -> list:
    """Return copies of the agent dicts with keys ordered Name, Profile, Icon, Classification, then the rest.

    Putting Name before Icon keeps the LLM from using the icon as the agent's name.
    A missing board (None) yields an empty list.
    """
    preferred_keys = ('Name', 'Profile', 'Icon', 'Classification')
    reordered = []
    for agent in agent_board or []:
        new_agent = {k: agent[k] for k in preferred_keys if k in agent}
        # Keep every remaining field, in its original order.
        for k, v in agent.items():
            if k not in new_agent:
                new_agent[k] = v
        reordered.append(new_agent)
    return reordered


def _generate_filled_step_task(request_id, incoming_data: dict, stepTask: dict) -> dict:
    """Select agents and generate the task process for ``stepTask``, streaming progress.

    Emits the ``starting``, ``agent_selection``, ``agent_selection_done``,
    ``task_process``, ``task_process_done`` and ``rendering`` stages. Sets
    ``AgentSelection`` and ``TaskProcess`` on ``stepTask`` in place and returns
    it after ``Add_Collaboration_Brief_FrontEnd``.
    """
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'starting',
        'message': f'🚀 开始填充步骤任务: {incoming_data.get("stepTask", {}).get("StepName", "")}'
    })

    # Stage 1: agent selection.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'agent_selection',
        'message': '👥 正在生成智能体选择...'
    })

    from AgentCoord.PlanEngine.AgentSelection_Generator import generate_AgentSelection
    Current_Task = {
        "TaskName": stepTask.get("StepName"),
        "InputObject_List": stepTask.get("InputObject_List"),
        "OutputObject": stepTask.get("OutputObject"),
        "TaskContent": stepTask.get("TaskContent"),
    }
    AgentSelection = generate_AgentSelection(
        General_Goal=incoming_data.get("General Goal"),
        Current_Task=Current_Task,
        Agent_Board=_agent_board_for_llm(AgentBoard),
    )

    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'agent_selection_done',
        'message': f'✅ 智能体选择完成: {", ".join(AgentSelection)}'
    })

    # Stage 2: task process.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'task_process',
        'message': '📝 正在生成任务流程...'
    })

    import AgentCoord.util as util
    from AgentCoord.PlanEngine.taskProcess_Generator import generate_TaskProcess
    Current_Task_Description = {
        "TaskName": stepTask.get("StepName"),
        # NOTE(review): raises KeyError if the LLM returns a name not on the board.
        "AgentInvolved": [
            {"Name": name, "Profile": AgentProfile_Dict[name]}
            for name in AgentSelection
        ],
        "InputObject_List": stepTask.get("InputObject_List"),
        "OutputObject": stepTask.get("OutputObject"),
        "CurrentTaskDescription": util.generate_template_sentence_for_CollaborationBrief(
            stepTask.get("InputObject_List"),
            stepTask.get("OutputObject"),
            AgentSelection,
            stepTask.get("TaskContent"),
        ),
    }
    TaskProcess = generate_TaskProcess(
        General_Goal=incoming_data.get("General Goal"),
        Current_Task_Description=Current_Task_Description,
    )

    stepTask["AgentSelection"] = AgentSelection
    stepTask["TaskProcess"] = TaskProcess

    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'task_process_done',
        'message': f'✅ 任务流程生成完成,共 {len(TaskProcess)} 个动作'
    })

    # Stage 3: front-end rendering spec.
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'rendering',
        'message': '🎨 正在处理渲染规范...'
    })

    return Add_Collaboration_Brief_FrontEnd(stepTask)


def _save_filled_step_to_outline(task_id, step_name, filled_stepTask: dict) -> None:
    """Replace the step named ``step_name`` in the stored task outline with ``filled_stepTask``.

    Does nothing when the task or its outline does not exist; an unmatched step
    name leaves the outline unchanged.
    """
    with get_db_context() as db:
        existing_task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if not (existing_task and existing_task.task_outline):
            return

        # Deep-copy so the ORM-loaded JSON is not mutated in place.
        task_outline = copy.deepcopy(existing_task.task_outline)
        collaboration_process = task_outline.get("Collaboration Process", [])
        for i, step in enumerate(collaboration_process):
            if step.get("StepName") == step_name:
                collaboration_process[i] = filled_stepTask
                break
        task_outline["Collaboration Process"] = collaboration_process

        # Update via raw SQL to bypass ORM transaction issues.
        db.execute(
            text("UPDATE multi_agent_tasks SET task_outline = :outline WHERE task_id = :id"),
            {"outline": json.dumps(task_outline), "id": task_id}
        )
        db.commit()


@socketio.on('fill_step_task')
def handle_fill_step_task_ws(data):
    """
    WebSocket version: fill one step (agent selection + task process) with streamed progress.

    Streaming events (progress, status "streaming"), in order:
    - stage "starting"
    - stage "agent_selection", then "agent_selection_done"
    - stage "task_process", then "task_process_done"
    - stage "rendering"
    - stage "complete"
    Final event:
    - response: {"id": request_id, "status": "success",
                 "data": {"task_id", "filled_stepTask"}}

    When ``task_id`` is given the matching step (by StepName) in the stored
    outline is replaced. A cache hit skips generation but still persists and
    responds with the same shape.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get("task_id")

    try:
        print(f"[WS] 开始处理 fill_step_task: request_id={request_id}, task_id={task_id}")

        requestIdentifier = str((
            "/fill_stepTask",
            incoming_data.get("General Goal"),
            incoming_data.get("stepTask"),
        ))

        stepTask = incoming_data.get("stepTask")
        if USE_CACHE and requestIdentifier in Request_Cache:
            filled_stepTask = Request_Cache[requestIdentifier]
        else:
            filled_stepTask = _generate_filled_step_task(request_id, incoming_data, stepTask)
            if USE_CACHE:
                Request_Cache[requestIdentifier] = filled_stepTask

        if task_id:
            _save_filled_step_to_outline(task_id, (stepTask or {}).get("StepName"), filled_stepTask)

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'complete',
            'message': '✅ 任务填充完成'
        })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': {
                "task_id": task_id,
                "filled_stepTask": filled_stepTask
            }
        })

    except Exception as e:
        print(f"[WS] fill_step_task 处理失败: {request_id}, error={str(e)}")
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
def _generate_task_process_with_progress(request_id, incoming_data: dict, stepTask: dict) -> dict:
    """Fill ``stepTask``'s TaskProcess via ``fill_stepTask_TaskProcess``, streaming progress.

    Emits the ``starting``, ``generating``, ``generated`` and ``rendering``
    stages and returns the step after ``Add_Collaboration_Brief_FrontEnd``.
    """
    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'starting',
        'message': f'🚀 开始生成任务流程: {stepTask.get("StepName", "")}'
    })

    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'generating',
        'message': '📝 正在生成任务流程...'
    })

    filled_stepTask = fill_stepTask_TaskProcess(
        General_Goal=incoming_data.get("General Goal"),
        stepTask=stepTask,
        AgentProfile_Dict=AgentProfile_Dict,
    )

    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'generated',
        'message': f'✅ 任务流程生成完成,共 {len(filled_stepTask.get("TaskProcess", []))} 个动作'
    })

    emit('progress', {
        'id': request_id,
        'status': 'streaming',
        'stage': 'rendering',
        'message': '🎨 正在处理渲染规范...'
    })

    return Add_Collaboration_Brief_FrontEnd(filled_stepTask)


def _save_agent_combination(task_id, step_id, agents: list, filled_stepTask: dict) -> bool:
    """Store this agent group's TaskProcess and brief in the task's ``assigned_agents``.

    Writes ``assigned_agents[step_id]["agent_combinations"][json.dumps(sorted(agents))]``
    = ``{"process": ..., "brief": ...}``.

    Returns:
        False (and writes nothing) when the task does not exist, True otherwise.
    """
    with get_db_context() as db:
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return False

        # Work on a copy so the ORM-loaded JSON is not mutated in place; anything
        # other than a dict is treated as empty.
        raw_assigned = task.assigned_agents
        existing_assigned = copy.deepcopy(raw_assigned) if isinstance(raw_assigned, dict) else {}

        step_entry = existing_assigned.setdefault(step_id, {})
        combinations = step_entry.setdefault("agent_combinations", {})

        # Order-independent key for the agent group.
        agent_group_key = json.dumps(sorted(agents))
        combinations[agent_group_key] = {
            "process": filled_stepTask.get("TaskProcess", []),
            "brief": filled_stepTask.get("Collaboration_Brief_frontEnd", {})
        }

        db.execute(
            text("UPDATE multi_agent_tasks SET assigned_agents = :assigned WHERE task_id = :id"),
            {"assigned": json.dumps(existing_assigned), "id": task_id}
        )
        db.commit()
    return True


@socketio.on('fill_step_task_process')
def handle_fill_step_task_process_ws(data):
    """
    WebSocket version: generate a step's task process with streamed progress.

    Streaming events (progress, status "streaming"), in order:
    - stage "starting", "generating", "generated", "rendering", "complete"
    Final event:
    - response: {"id": request_id, "status": "success", "data": filled_stepTask}

    When ``task_id`` and ``agents`` are given and the step has an Id, the
    result is saved under that step's ``agent_combinations`` in
    ``assigned_agents``. A cache hit skips generation but is still persisted.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        requestIdentifier = str((
            "/fill_stepTask_TaskProcess",
            incoming_data.get("General Goal"),
            incoming_data.get("stepTask_lackTaskProcess"),
        ))

        stepTask = incoming_data.get("stepTask_lackTaskProcess")
        if USE_CACHE and requestIdentifier in Request_Cache:
            filled_stepTask = Request_Cache[requestIdentifier]
        else:
            filled_stepTask = _generate_task_process_with_progress(request_id, incoming_data, stepTask)
            if USE_CACHE:
                Request_Cache[requestIdentifier] = filled_stepTask

        # Persist the TaskProcess for this agent combination.
        task_id = incoming_data.get("task_id")
        agents = incoming_data.get("agents", [])
        step_info = stepTask or {}
        step_id = step_info.get("Id") or step_info.get("id")
        if task_id and agents and step_id:
            if _save_agent_combination(task_id, step_id, agents, filled_stepTask):
                print(f"[fill_step_task_process] 已保存 agent_combinations: task_id={task_id}, step_id={step_id}, agents={agents}")

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'complete',
            'message': '✅ 任务流程生成完成'
        })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': filled_stepTask
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('branch_plan_outline')
def handle_branch_plan_outline_ws(data):
    """
    WebSocket version: generate alternative ("branch") plan outlines with streamed progress.

    Request ``data`` fields: branch_Number, Modification_Requirement,
    Existing_Steps, Baseline_Completion, "Initial Input Object", "General Goal".

    Streaming events (progress, status "streaming"), in order:
    - stage "starting" (total_branches)
    - per branch: stage "generating_branch", then "branch_done" (branch, total, steps_count)
    - stage "rendering"
    - stage "complete"
    Final event:
    - response: {"id": request_id, "status": "success", "data": branchList}
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        requestIdentifier = str((
            "/branch_PlanOutline",
            incoming_data.get("branch_Number"),
            incoming_data.get("Modification_Requirement"),
            incoming_data.get("Existing_Steps"),
            incoming_data.get("Baseline_Completion"),
            incoming_data.get("Initial Input Object"),
            incoming_data.get("General Goal"),
        ))

        if USE_CACHE and requestIdentifier in Request_Cache:
            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': Request_Cache[requestIdentifier]
            })
            return

        # Validate up front; accept numeric strings sent by the client.
        branch_Number = incoming_data.get("branch_Number")
        if branch_Number is None:
            raise ValueError("缺少branch_Number参数")
        branch_Number = int(branch_Number)

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'starting',
            'total_branches': branch_Number,
            'message': f'🚀 开始生成分支大纲,共 {branch_Number} 个分支...'
        })

        from AgentCoord.util.converter import read_LLM_Completion
        from AgentCoord.PlanEngine.branch_PlanOutline import JSON_PLAN_OUTLINE_BRANCHING

        # ensure_ascii=False keeps Chinese step content readable in the prompt
        # instead of \uXXXX escapes.
        general_goal = incoming_data.get("General Goal")
        initial_input_object = incoming_data.get("Initial Input Object")
        existing_steps_json = json.dumps(
            incoming_data.get("Existing_Steps"), indent=4, ensure_ascii=False
        )
        baseline_completion_json = json.dumps(
            incoming_data.get("Baseline_Completion"), indent=4, ensure_ascii=False
        )
        modification_requirement = incoming_data.get("Modification_Requirement")

        prompt = f"""
## Instruction
Based on "Existing Steps", your task is to comeplete the "Remaining Steps" for the plan for "General Goal".
Note: "Modification Requirement" specifies how to modify the "Baseline Completion" for a better/alternative solution.

**IMPORTANT LANGUAGE REQUIREMENT: You must respond in Chinese (中文) for all content, including StepName, TaskContent, and OutputObject fields.**

## General Goal (Specify the general goal for the plan)
{general_goal}

## Initial Key Object List (Specify the list of initial key objects available for use as the input object of a Step)
{initial_input_object}

## Existing Steps
{existing_steps_json}

## Baseline Completion
{baseline_completion_json}

## Modification Requirement
{modification_requirement}
"""

        # The schema instruction is identical for every branch; build it once.
        schema_content = f" The JSON object must use the schema: {json.dumps(JSON_PLAN_OUTLINE_BRANCHING.model_json_schema(), indent=2)}"

        branch_List = []
        for i in range(branch_Number):
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'generating_branch',
                'branch': i + 1,
                'total': branch_Number,
                'message': f'🌿 正在生成分支大纲 {i+1}/{branch_Number}...'
            })

            # A fresh message list per call, in case the LLM helper mutates it.
            messages = [
                {"role": "system", "content": schema_content},
                {"role": "system", "content": prompt},
            ]
            Remaining_Steps = read_LLM_Completion(messages, useGroq=False)["Remaining Steps"]
            branch_List.append(Remaining_Steps)

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'branch_done',
                'branch': i + 1,
                'total': branch_Number,
                'steps_count': len(Remaining_Steps),
                'message': f'✅ 分支 {i+1}/{branch_Number} 生成完成,包含 {len(Remaining_Steps)} 个步骤'
            })

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'rendering',
            'message': '🎨 正在处理渲染规范...'
        })

        branchList = Add_Collaboration_Brief_FrontEnd(branch_List)

        if USE_CACHE:
            Request_Cache[requestIdentifier] = branchList

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'complete',
            'message': f'✅ 分支大纲生成完成,共 {branch_Number} 个分支'
        })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': branchList
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('branch_task_process')
def handle_branch_task_process_ws(data):
    """
    WebSocket version of branch_TaskProcess: generate alternative task-process
    branches for one step, one branch at a time, streaming progress as it goes.

    Request payload (``data['data']``): ``branch_Number`` (int), ``Modification_Requirement``,
    ``Existing_Steps``, ``Baseline_Completion``, ``stepTaskExisting`` (dict read for
    StepName / AgentSelection / InputObject_List / OutputObject / TaskContent) and
    ``General Goal``.

    Streamed events:
    - progress: {"id": request_id, "status": "streaming", "stage": "starting", "total_branches": N, "message": ...}
    - progress: {"id": request_id, "status": "streaming", "stage": "generating_branch", "branch": i, "total": N, "message": ...}
    - progress: {"id": request_id, "status": "streaming", "stage": "branch_done", "branch": i, "total": N, "actions_count": k, "message": ...}
    - progress: {"id": request_id, "status": "streaming", "stage": "complete", "message": ...}
    - response: {"id": request_id, "status": "success", "data": branch_List}
    - response: {"id": request_id, "status": "error", "error": str(e)} on any failure

    When USE_CACHE is on and the same inputs were seen before, only the success
    response is emitted.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        # Cache key covers every input that feeds the prompt.
        requestIdentifier = str((
            "/branch_TaskProcess",
            incoming_data.get("branch_Number"),
            incoming_data.get("Modification_Requirement"),
            incoming_data.get("Existing_Steps"),
            incoming_data.get("Baseline_Completion"),
            incoming_data.get("stepTaskExisting"),
            incoming_data.get("General Goal"),
        ))

        if USE_CACHE and requestIdentifier in Request_Cache:
            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': Request_Cache[requestIdentifier]
            })
            return

        branch_Number = incoming_data.get("branch_Number")
        stepTaskExisting = incoming_data.get("stepTaskExisting")
        # Validate up front so the client gets a meaningful error instead of a bare
        # TypeError from range(None) or AttributeError from None.get(...).
        if not isinstance(branch_Number, int):
            raise ValueError(f"branch_Number 必须是整数: {branch_Number!r}")
        if not isinstance(stepTaskExisting, dict):
            raise ValueError("缺少 stepTaskExisting 参数")

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'starting',
            'total_branches': branch_Number,
            'message': f'🚀 开始生成分支任务流程,共 {branch_Number} 个分支...'
        })

        # Generate the task-process branches one by one.
        from AgentCoord.util.converter import read_LLM_Completion
        from AgentCoord.PlanEngine.branch_TaskProcess import (
            JSON_TASK_PROCESS_BRANCHING,
            ACT_SET,
            PROMPT_TASK_PROCESS_BRANCHING
        )
        import AgentCoord.util as util

        def _profile_text(name):
            # set_agents stores a config dict per agent (profile plus apiUrl/apiKey/
            # apiModel). Only the profile text belongs in the LLM prompt; embedding
            # the whole dict would send API credentials to the model provider.
            # Plain-string entries are passed through unchanged.
            entry = AgentProfile_Dict[name]
            if isinstance(entry, dict):
                return entry.get("profile", "")
            return entry

        Current_Task_Description = {
            "TaskName": stepTaskExisting.get("StepName"),
            "AgentInvolved": [
                {"Name": name, "Profile": _profile_text(name)}
                for name in stepTaskExisting.get("AgentSelection", [])
            ],
            "InputObject_List": stepTaskExisting.get("InputObject_List"),
            "OutputObject": stepTaskExisting.get("OutputObject"),
            "CurrentTaskDescription": util.generate_template_sentence_for_CollaborationBrief(
                stepTaskExisting.get("InputObject_List"),
                stepTaskExisting.get("OutputObject"),
                stepTaskExisting.get("AgentSelection"),
                stepTaskExisting.get("TaskContent"),
            ),
        }

        prompt = PROMPT_TASK_PROCESS_BRANCHING.format(
            Modification_Requirement=incoming_data.get("Modification_Requirement"),
            Current_Task_Description=json.dumps(Current_Task_Description, indent=4),
            Existing_Steps=json.dumps(incoming_data.get("Existing_Steps"), indent=4),
            Baseline_Completion=json.dumps(incoming_data.get("Baseline_Completion"), indent=4),
            General_Goal=incoming_data.get("General Goal"),
            Act_Set=ACT_SET,
        )

        # The schema message does not change between iterations; build it once.
        schema_message = {
            "role": "system",
            "content": f" The JSON object must use the schema: {json.dumps(JSON_TASK_PROCESS_BRANCHING.model_json_schema(), indent=2)}",
        }

        branch_List = []
        for i in range(branch_Number):
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'generating_branch',
                'branch': i + 1,
                'total': branch_Number,
                'message': f'🌿 正在生成分支任务流程 {i+1}/{branch_Number}...'
            })

            messages = [
                schema_message,
                {"role": "system", "content": prompt},
            ]
            Remaining_Steps = read_LLM_Completion(messages, useGroq=False)["Remaining Steps"]
            branch_List.append(Remaining_Steps)

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'branch_done',
                'branch': i + 1,
                'total': branch_Number,
                'actions_count': len(Remaining_Steps),
                'message': f'✅ 分支 {i+1}/{branch_Number} 生成完成,包含 {len(Remaining_Steps)} 个动作'
            })

        if USE_CACHE:
            Request_Cache[requestIdentifier] = branch_List

        emit('progress', {
            'id': request_id,
            'status': 'streaming',
            'stage': 'complete',
            'message': f'✅ 分支任务流程生成完成,共 {branch_Number} 个分支'
        })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': branch_List
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('agent_select_modify_init')
def handle_agent_select_modify_init_ws(data):
    """
    WebSocket version of agentSelectModify_init: derive ability requirements for a
    step, score every agent on them, persist the scores and stream progress.

    Request payload (``data['data']``): ``General Goal``, ``stepTask`` (dict read for
    StepName / InputObject_List / OutputObject / TaskContent and ``Id`` or ``id``) and
    optional ``task_id``.

    Progress events (emitted only when the score table is freshly computed):
    - stage "starting", "generating_requirements"
    - stage "requirements_generated" with "requirements": [...]
    - stage "scoring" with "total_aspects": N
    - stage "aspect_scored" with "aspect": i, "total": N, "ability": name
    - stage "complete"

    Final event, identical for cached and fresh results:
    - response: {"id": request_id, "status": "success", "data": {"task_id": ..., "scoreTable": ...}}
    - response: {"id": request_id, "status": "error", "error": str(e)} on failure

    Only agent_scores is saved here. assigned_agents is written by fill_step_task
    once AgentSelection has a value.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        requestIdentifier = str((
            "/agentSelectModify_init",
            incoming_data.get("General Goal"),
            incoming_data.get("stepTask"),
        ))

        stepTask = incoming_data.get("stepTask") or {}
        task_id = incoming_data.get("task_id")

        # A cache hit only skips the LLM work. The cache key does not include
        # task_id, so the cached table must still be saved for this task, and the
        # response must have the same shape as a fresh computation.
        from_cache = USE_CACHE and requestIdentifier in Request_Cache
        Ability_Requirement_List = None

        if from_cache:
            scoreTable = Request_Cache[requestIdentifier]
        else:
            if not stepTask:
                raise ValueError("缺少 stepTask 参数")

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'starting',
                'message': '🚀 开始生成智能体能力需求...'
            })

            from AgentCoord.util.converter import read_LLM_Completion
            from AgentCoord.PlanEngine.AgentSelectModify import (
                JSON_ABILITY_REQUIREMENT_GENERATION,
                PROMPT_ABILITY_REQUIREMENT_GENERATION,
                agentAbilityScoring
            )

            # Stage 1: ask the LLM for the ability requirements of this step.
            Current_Task = {
                "TaskName": stepTask.get("StepName"),
                "InputObject_List": stepTask.get("InputObject_List"),
                "OutputObject": stepTask.get("OutputObject"),
                "TaskContent": stepTask.get("TaskContent"),
            }

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'generating_requirements',
                'message': '📋 正在生成能力需求列表...'
            })

            messages = [
                {
                    "role": "system",
                    "content": f" The JSON object must use the schema: {json.dumps(JSON_ABILITY_REQUIREMENT_GENERATION.model_json_schema(), indent=2)}",
                },
                {
                    "role": "system",
                    "content": PROMPT_ABILITY_REQUIREMENT_GENERATION.format(
                        General_Goal=incoming_data.get("General Goal"),
                        Current_Task=json.dumps(Current_Task, indent=4),
                    ),
                },
            ]
            Ability_Requirement_List = read_LLM_Completion(messages)["AbilityRequirement"]

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'requirements_generated',
                'requirements': Ability_Requirement_List,
                'message': f'✅ 能力需求生成完成: {", ".join(Ability_Requirement_List)}'
            })

            # Stage 2: score every agent against each ability requirement.
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'scoring',
                'total_aspects': len(Ability_Requirement_List),
                'message': f'📊 开始为 {len(Ability_Requirement_List)} 个能力需求评分...'
            })

            scoreTable = agentAbilityScoring(AgentBoard, Ability_Requirement_List)

            # agentAbilityScoring scores everything in one call; these per-ability
            # events are reported after it returns.
            for idx, ability in enumerate(scoreTable, 1):
                emit('progress', {
                    'id': request_id,
                    'status': 'streaming',
                    'stage': 'aspect_scored',
                    'aspect': idx,
                    'total': len(Ability_Requirement_List),
                    'ability': ability,
                    'message': f'✅ 能力 "{ability}" 评分完成 ({idx}/{len(Ability_Requirement_List)})'
                })

            if USE_CACHE:
                Request_Cache[requestIdentifier] = scoreTable

        # Persist agent_scores, keyed by the step-level id.
        step_id = stepTask.get("Id") or stepTask.get("id")
        if task_id and step_id:
            with get_db_context() as db:
                front_format = convert_score_table_to_front_format(scoreTable)
                step_scores = {step_id: front_format}
                MultiAgentTaskCRUD.update_agent_scores(db, task_id, step_scores)
                print(f"[agent_select_modify_init] 已保存 agent_scores: step_id={step_id}")

        if not from_cache:
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'complete',
                'message': f'✅ 智能体评分完成,共 {len(Ability_Requirement_List)} 个能力维度'
            })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': {
                "task_id": task_id,
                "scoreTable": scoreTable
            }
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('agent_select_modify_add_aspect')
def handle_agent_select_modify_add_aspect_ws(data):
    """
    WebSocket version of agentSelectModify_addAspect: score all agents on a newly
    added evaluation aspect and merge the result into the stored agent_scores.

    Request payload (``data['data']``): ``aspectList`` (the full list of aspects, whose
    last element is the new one), optional ``task_id`` and ``stepTask`` (``Id``/``id``
    names the step the scores belong to).

    Progress events (emitted only when the scores are freshly computed):
    - stage "starting", then stage "scoring" with "aspect": <new aspect>
    - stage "complete"
    Final event, identical for cached and fresh results:
    - response: {"id": request_id, "status": "success", "data": {"task_id": ..., "scoreTable": ...}}
    - response: {"id": request_id, "status": "error", "error": str(e)} on failure
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    try:
        aspectList = incoming_data.get("aspectList")
        if not aspectList:
            raise ValueError("缺少 aspectList 参数")
        # The last entry of aspectList is the aspect being added.
        newAspect = aspectList[-1]

        requestIdentifier = str((
            "/agentSelectModify_addAspect",
            aspectList,
        ))

        # A cache hit only skips the scoring call. Persistence and the response shape
        # are the same as for a fresh computation.
        from_cache = USE_CACHE and requestIdentifier in Request_Cache
        if from_cache:
            scoreTable = Request_Cache[requestIdentifier]
        else:
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'starting',
                'message': f'🚀 开始为新能力维度评分: {newAspect or "Unknown"}'
            })

            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'scoring',
                'aspect': newAspect,
                'message': f'📊 正在为能力 "{newAspect}" 评分...'
            })

            scoreTable = AgentSelectModify_addAspect(
                aspectList=aspectList,
                Agent_Board=AgentBoard
            )

            if USE_CACHE:
                Request_Cache[requestIdentifier] = scoreTable

        # Merge the new scores into the stored per-step agent_scores.
        task_id = incoming_data.get("task_id")
        stepTask = incoming_data.get("stepTask") or {}
        step_id = stepTask.get("Id") or stepTask.get("id")  # step-level id
        if task_id and step_id:
            with get_db_context() as db:
                task = MultiAgentTaskCRUD.get_by_id(db, task_id)
                if not task:
                    raise ValueError(f'任务不存在: {task_id}')

                # Work on a copy instead of mutating the ORM-loaded JSON in place.
                existing_scores = dict(task.agent_scores or {})
                existing_step_data = existing_scores.get(step_id, {})

                # Append new aspects after the existing ones, skipping ones already stored.
                old_aspects = existing_step_data.get("aspectList", [])
                known_aspects = set(old_aspects)
                new_aspects = [a for a in aspectList if a not in known_aspects]
                merged_aspect_list = old_aspects + new_aspects

                # Keep every old agent score and overlay the freshly computed ones.
                merged_agent_scores = {
                    agent: dict(scores)
                    for agent, scores in existing_step_data.get("agentScores", {}).items()
                }
                new_front_format = convert_score_table_to_front_format(scoreTable)
                for agent, scores in new_front_format.get("agentScores", {}).items():
                    merged_agent_scores.setdefault(agent, {}).update(scores)

                existing_scores[step_id] = {
                    "aspectList": merged_aspect_list,
                    "agentScores": merged_agent_scores
                }
                db.execute(
                    text("UPDATE multi_agent_tasks SET agent_scores = :scores WHERE task_id = :id"),
                    {"scores": json.dumps(existing_scores), "id": task_id}
                )
                db.commit()
                print(f"[agent_select_modify_add_aspect] 已追加保存 agent_scores: step_id={step_id}, 新增维度={new_aspects}")

        if not from_cache:
            emit('progress', {
                'id': request_id,
                'status': 'streaming',
                'stage': 'complete',
                'message': f'✅ 能力 "{newAspect}" 评分完成'
            })

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': {
                "task_id": task_id,
                "scoreTable": scoreTable
            }
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
def _remove_aspect_from_step_scores(step_scores, aspect_name):
    """Remove ``aspect_name`` in place from one step's stored score entry.

    ``step_scores`` may contain ``aspectList`` (list of aspect names) and
    ``agentScores`` (agent name -> {aspect name: score info}). Missing keys are
    tolerated.
    """
    if aspect_name in step_scores.get('aspectList', []):
        step_scores['aspectList'] = [a for a in step_scores['aspectList'] if a != aspect_name]
    for per_agent in step_scores.get('agentScores', {}).values():
        per_agent.pop(aspect_name, None)


@socketio.on('agent_select_modify_delete_aspect')
def handle_agent_select_modify_delete_aspect_ws(data):
    """
    WebSocket version: delete an evaluation aspect from a task's stored agent_scores.

    Request format:
    {
        "id": "request-id",
        "data": {
            "task_id": "task-id",
            "step_id": "step-id",          # optional
            "aspect_name": "aspect to delete"
        }
    }

    With ``step_id`` only that step is touched. A step with no stored scores is
    left alone. Without ``step_id`` the aspect is removed from every step.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get('task_id')
    step_id = incoming_data.get('step_id')
    aspect_name = incoming_data.get('aspect_name')

    if not task_id or not aspect_name:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少必要参数 task_id 或 aspect_name'
        })
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, task_id)
            if not task:
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': f'任务不存在: {task_id}'
                })
                return

            # Deep copy so the ORM-loaded JSON is not mutated in place.
            existing_scores = copy.deepcopy(task.agent_scores or {})

            if step_id:
                # Restrict the deletion to the requested step. Previously an unknown
                # step_id fell through to the all-steps branch and removed the aspect
                # from every step.
                if step_id in existing_scores:
                    _remove_aspect_from_step_scores(existing_scores[step_id], aspect_name)
                    print(f"[agent_select_modify_delete_aspect] 已删除维度 from step_id={step_id}, 维度={aspect_name}")
                else:
                    print(f"[agent_select_modify_delete_aspect] step_id={step_id} 无评分数据,跳过")
            else:
                for step_scores in existing_scores.values():
                    _remove_aspect_from_step_scores(step_scores, aspect_name)
                print(f"[agent_select_modify_delete_aspect] 已删除所有步骤中的维度,维度={aspect_name}")

            db.execute(
                text("UPDATE multi_agent_tasks SET agent_scores = :scores WHERE task_id = :id"),
                {"scores": json.dumps(existing_scores), "id": task_id}
            )
            db.commit()

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    "message": f"维度 '{aspect_name}' 删除成功",
                    "task_id": task_id,
                    "deleted_aspect": aspect_name
                }
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('set_agents')
def handle_set_agents_ws(data):
    """
    WebSocket version of setAgents: replace the in-memory agent board and upsert
    every agent into the ``user_agents`` table.

    Each item of ``data['data']`` must have ``Name`` and ``Profile``. ``Icon``,
    ``Classification`` and ``user_id`` (default "default_user") are optional. When
    ``apiUrl``, ``apiKey`` and ``apiModel`` are all non-empty the agent uses that
    custom API; otherwise the OPENAI_* values from config.yaml are stored.

    The globals AgentBoard / AgentProfile_Dict are replaced only after every agent
    has been processed and saved. A failure part-way therefore leaves the previous
    board in place instead of a half-built one.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})

    # yaml_data is only read here, so it does not need to be declared global.
    global AgentBoard, AgentProfile_Dict

    try:
        new_profiles = {}
        saved_agents = []
        with get_db_context() as db:
            for item in incoming_data:
                name = item["Name"]
                if all(item.get(field) for field in ["apiUrl", "apiKey", "apiModel"]):
                    agent_config = {
                        "profile": item["Profile"],
                        "Icon": item.get("Icon", ""),
                        "Classification": item.get("Classification", ""),
                        "apiUrl": item["apiUrl"],
                        "apiKey": item["apiKey"],
                        "apiModel": item["apiModel"],
                        "useCustomAPI": True
                    }
                else:
                    agent_config = {
                        "profile": item["Profile"],
                        "Icon": item.get("Icon", ""),
                        "Classification": item.get("Classification", ""),
                        "apiUrl": yaml_data.get("OPENAI_API_BASE"),
                        "apiKey": yaml_data.get("OPENAI_API_KEY"),
                        "apiModel": yaml_data.get("OPENAI_API_MODEL"),
                        "useCustomAPI": False
                    }
                new_profiles[name] = agent_config

                # Upsert: same user_id + agent_name updates, otherwise creates.
                user_id = item.get("user_id", "default_user")
                agent = UserAgentCRUD.upsert(
                    db=db,
                    user_id=user_id,
                    agent_name=name,
                    agent_config=agent_config,
                )
                saved_agents.append(agent.to_dict())

        # Everything succeeded: publish the new board atomically.
        AgentBoard = incoming_data
        AgentProfile_Dict = new_profiles

        emit('response', {
            'id': request_id,
            'status': 'success',
            'data': {
                "code": 200,
                "content": "set agentboard successfully",
                "saved_agents": saved_agents
            }
        })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('get_agents')
def handle_get_agents_ws(data):
    """
    WebSocket version of getAgents: load the caller's agent configurations from
    the ``user_agents`` table.

    The payload is under ``data['data']``, and ``user_id`` is optional. If it is
    missing, a fresh UUID4 is generated, used for the lookup and echoed back as
    ``user_id`` in the response so the client can keep it.
    """
    req_id = data.get('id')
    payload = data.get('data', {})

    user_id = payload.get('user_id') if isinstance(payload, dict) else None
    generated_id = None
    if user_id:
        print(f"[get_agents] 接收到的 user_id: {user_id}")
    else:
        generated_id = str(uuid.uuid4())
        user_id = generated_id
        print(f"[get_agents] 新生成 user_id: {generated_id}")

    def _as_frontend(record):
        # Flatten a stored row into the shape the frontend expects.
        cfg = record.agent_config or {}
        return {
            'Name': record.agent_name,
            'Profile': cfg.get('profile', ''),
            'Icon': cfg.get('Icon', ''),
            'Classification': cfg.get('Classification', ''),
            'apiUrl': cfg.get('apiUrl', ''),
            'apiKey': cfg.get('apiKey', ''),
            'apiModel': cfg.get('apiModel', ''),
        }

    try:
        with get_db_context() as db:
            rows = UserAgentCRUD.get_by_user_id(db=db, user_id=user_id)
            result = {
                'code': 200,
                'content': 'get agents successfully',
                'agents': [_as_frontend(row) for row in rows]
            }
            if generated_id:
                result['user_id'] = generated_id

            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': result
            })
    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
@socketio.on('stop_generation')
def handle_stop_generation(data):
    """
    WebSocket version: stop plan generation.

    Request: {"id": "request-id", "action": "stop_generation", "data": {"goal": "..."}}

    For now this only acknowledges the request. No running generation is actually
    interrupted yet (see the TODO below).
    """
    req_id = data.get('id')
    payload = data.get('data', {})

    try:
        # Read for the future stop logic; not used yet.
        _goal = payload.get('goal', '')

        # TODO: implement real cancellation, e.g. a global stop flag that running
        # generation tasks check and honour.
        ack = {"message": "已发送停止信号"}
        emit('response', {'id': req_id, 'status': 'success', 'data': ack})
    except Exception as exc:
        emit('response', {'id': req_id, 'status': 'error', 'error': str(exc)})
|
||
|
||
|
||
@socketio.on('pause_execution')
def handle_pause_execution(data):
    """
    WebSocket version: pause a running execution.

    Request: {"id": "request-id", "action": "pause_execution", "data": {"execution_id": "..."}}

    Delegates to ``execution_state_manager.pause_execution`` and replies with a
    single ``response`` event: success when the manager reports True, otherwise an
    error.
    """
    req_id = data.get('id')
    payload = data.get('data', {})

    try:
        from AgentCoord.RehearsalEngine_V2.execution_state import execution_state_manager

        exec_id = payload.get('execution_id', '')
        if not exec_id:
            emit('response', {'id': req_id, 'status': 'error', 'error': '缺少 execution_id'})
            return

        if execution_state_manager.pause_execution(exec_id):
            print(f"[WS] pause_execution 成功: execution_id={exec_id}")
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': {"message": "已暂停执行,可随时继续"}
            })
        else:
            print(f"[WS] pause_execution 失败: execution_id={exec_id}")
            emit('response', {'id': req_id, 'status': 'error', 'error': '无法暂停'})

    except Exception as exc:
        print(f"[WS] pause_execution 异常: {str(exc)}")
        emit('response', {'id': req_id, 'status': 'error', 'error': str(exc)})
|
||
|
||
|
||
@socketio.on('resume_execution')
def handle_resume_execution(data):
    """
    WebSocket version: resume a paused execution.

    Request: {"id": "request-id", "action": "resume_execution", "data": {"execution_id": "..."}}

    Delegates to ``execution_state_manager.resume_execution`` and replies with a
    single ``response`` event: success when the manager reports True, otherwise an
    error.
    """
    req_id = data.get('id')
    payload = data.get('data', {})

    try:
        from AgentCoord.RehearsalEngine_V2.execution_state import execution_state_manager

        exec_id = payload.get('execution_id', '')
        if not exec_id:
            emit('response', {'id': req_id, 'status': 'error', 'error': '缺少 execution_id'})
            return

        if execution_state_manager.resume_execution(exec_id):
            print(f"[WS] resume_execution 成功: execution_id={exec_id}")
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': {"message": "已恢复执行"}
            })
        else:
            print(f"[WS] resume_execution 失败: execution_id={exec_id}")
            emit('response', {'id': req_id, 'status': 'error', 'error': '无法恢复'})

    except Exception as exc:
        print(f"[WS] resume_execution 异常: {str(exc)}")
        emit('response', {'id': req_id, 'status': 'error', 'error': str(exc)})
|
||
|
||
|
||
@socketio.on('stop_execution')
def handle_stop_execution(data):
    """
    WebSocket version: stop a running execution.

    Request: {"id": "request-id", "action": "stop_execution", "data": {"execution_id": "..."}}

    Delegates to ``execution_state_manager.stop_execution`` and replies with a
    single ``response`` event: success when the manager reports True, otherwise an
    error.
    """
    req_id = data.get('id')
    payload = data.get('data', {})

    try:
        from AgentCoord.RehearsalEngine_V2.execution_state import execution_state_manager

        exec_id = payload.get('execution_id', '')
        if not exec_id:
            emit('response', {'id': req_id, 'status': 'error', 'error': '缺少 execution_id'})
            return

        if execution_state_manager.stop_execution(exec_id):
            print(f"[WS] stop_execution 成功: execution_id={exec_id}")
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': {"message": "已停止执行"}
            })
        else:
            print(f"[WS] stop_execution 失败: execution_id={exec_id}")
            emit('response', {'id': req_id, 'status': 'error', 'error': '无法停止'})

    except Exception as exc:
        print(f"[WS] stop_execution 异常: {str(exc)}")
        emit('response', {'id': req_id, 'status': 'error', 'error': str(exc)})
|
||
|
||
|
||
# ==================== 历史记录管理 ====================
|
||
|
||
@socketio.on('get_plans')
def handle_get_plans(data):
    """
    WebSocket version: list the 50 most recent history tasks, optionally filtered
    by ``user_id``.

    Message shape: {"id": "get_plans-xxx", "action": "get_plans",
    "data": {"id": "ws_req_xxx", "user_id": "xxx"}}. The outer ``id`` is echoed back
    so the client can match the response, and ``user_id`` is read from ``data``.

    Legacy rows whose ``branches`` column is a bare list are returned as
    {"flow_branches": [...], "task_process_branches": {}}.
    """
    req_id = data.get('id')
    payload = data.get('data', {})
    user_id = payload.get('user_id') if isinstance(payload, dict) else None

    def _normalize_branches(raw):
        # The old format stored only the flow branches as an array.
        if raw and isinstance(raw, list):
            return {'flow_branches': raw, 'task_process_branches': {}}
        return raw or {}

    def _summarize(task):
        return {
            "id": task.task_id,  # task_id is the unique identifier
            "general_goal": task.query or '未知任务',
            "status": task.status.value if task.status else 'unknown',
            "execution_count": task.execution_count or 0,
            "created_at": task.created_at.isoformat() if task.created_at else None,
            "is_pinned": task.is_pinned or False,
            # Full data so the client can restore the plan.
            "task_outline": task.task_outline,
            "assigned_agents": task.assigned_agents,
            "agent_scores": task.agent_scores,
            "agents_info": task.agents_info,
            "branches": _normalize_branches(task.branches),
        }

    try:
        with get_db_context() as db:
            recent = MultiAgentTaskCRUD.get_recent(db, limit=50, user_id=user_id)
            plans = [_summarize(task) for task in recent]
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': plans
            })
    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
@socketio.on('restore_plan')
def handle_restore_plan(data):
    """
    WebSocket version: load one history task so the client can restore it.

    Envelope: {"id": "restore_plan-xxx", "action": "restore_plan",
               "data": {"id": "ws_req_xxx", "data": {"plan_id": ...}}}
    ``plan_id`` is the task_id. Restoring does not increment execution_count; that
    counter only changes when a task is actually executed.
    """
    req_id = data.get('id')  # outer envelope id, echoed back for matching
    body = data.get('data', {}).get('data', {})
    plan_id = body.get('plan_id')

    if not plan_id:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': '缺少 plan_id(task_id)'
        })
        return

    try:
        with get_db_context() as db:
            record = MultiAgentTaskCRUD.get_by_id(db, plan_id)
            if not record:
                emit('response', {
                    'id': req_id,
                    'status': 'error',
                    'error': f'任务不存在: {plan_id}'
                })
                return

            branches = record.branches
            if branches and isinstance(branches, list):
                # The old format stored only the flow branches as an array.
                branches = {'flow_branches': branches, 'task_process_branches': {}}

            snapshot = {
                "id": record.task_id,
                "general_goal": record.query or '未知任务',
                "status": record.status.value if record.status else 'unknown',
                "execution_count": record.execution_count or 0,
                "created_at": record.created_at.isoformat() if record.created_at else None,
                "task_outline": record.task_outline,
                "assigned_agents": record.assigned_agents,
                "agent_scores": record.agent_scores,
                "agents_info": record.agents_info,
                "branches": branches or {},
                # Full rehearsal log, used to restore the execution state.
                "rehearsal_log": record.rehearsal_log,
            }

            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': snapshot
            })

    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
@socketio.on('get_agent_scores')
def handle_get_agent_scores(data):
    """
    WebSocket version: return the stored agent_scores of one task.

    Request: {"id": "request-id", "action": "get_agent_scores", "data": {"task_id": "task-id"}}

    Success ``data`` (matches the frontend ITaskScoreData type):
    {
        "task_id": "xxx",
        "agent_scores": {
            "stepId1": {
                "aspectList": ["<aspect>", ...],
                "agentScores": {"Agent-A": {"<aspect>": {"score": 4.5, "reason": "..."}}},
                "timestamp": 1699999999999
            }
        }
    }
    A task with no stored scores yields ``"agent_scores": {}``.
    """
    req_id = data.get('id')
    wanted_task = data.get('data', {}).get('task_id')

    if not wanted_task:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': '缺少 task_id 参数'
        })
        return

    try:
        with get_db_context() as db:
            record = MultiAgentTaskCRUD.get_by_id(db, wanted_task)
            if not record:
                emit('response', {
                    'id': req_id,
                    'status': 'error',
                    'error': f'任务不存在: {wanted_task}'
                })
                return

            payload = {
                "task_id": wanted_task,
                "agent_scores": record.agent_scores or {}
            }
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': payload
            })

    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
@socketio.on('delete_plan')
def handle_delete_plan(data):
    """
    WebSocket version: delete a history task.

    Envelope: {"id": "delete_plan-xxx", "action": "delete_plan",
               "data": {"id": "ws_req_xxx", "data": {"plan_id": ...}}}
    ``plan_id`` is the task_id. On success every connected client receives
    ``history_updated`` so it can refresh its history list.
    """
    req_id = data.get('id')  # outer envelope id, echoed back for matching
    body = data.get('data', {}).get('data', {})
    plan_id = body.get('plan_id')

    if not plan_id:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': '缺少 plan_id(task_id)'
        })
        return

    try:
        with get_db_context() as db:
            deleted = MultiAgentTaskCRUD.delete(db, plan_id)
            if not deleted:
                emit('response', {
                    'id': req_id,
                    'status': 'error',
                    'error': f'任务不存在或删除失败: {plan_id}'
                })
                return

            # Broadcast to all clients, not only the requester.
            socketio.emit('history_updated', {'task_id': plan_id})
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': {"message": "删除成功"}
            })

    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
@socketio.on('pin_plan')
def handle_pin_plan(data):
    """
    WebSocket version: pin or unpin a history task.

    Envelope: {"id": "pin_plan-xxx", "action": "pin_plan",
               "data": {"id": "ws_req_xxx", "data": {"plan_id": ..., "is_pinned": bool}}}
    ``plan_id`` is the task_id; ``is_pinned`` defaults to True. On success every
    connected client receives ``history_updated`` so it can refresh its list.
    """
    req_id = data.get('id')  # outer envelope id, echoed back for matching
    body = data.get('data', {}).get('data', {})
    plan_id = body.get('plan_id')
    pin = body.get('is_pinned', True)

    if not plan_id:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': '缺少 plan_id(task_id)'
        })
        return

    try:
        with get_db_context() as db:
            updated = MultiAgentTaskCRUD.update_is_pinned(db, plan_id, pin)
            if not updated:
                emit('response', {
                    'id': req_id,
                    'status': 'error',
                    'error': f'任务不存在: {plan_id}'
                })
                return

            # Broadcast to all clients, not only the requester.
            socketio.emit('history_updated', {'task_id': plan_id})
            outcome = "置顶成功" if pin else "取消置顶成功"
            emit('response', {
                'id': req_id,
                'status': 'success',
                'data': {"message": outcome}
            })

    except Exception as exc:
        emit('response', {
            'id': req_id,
            'status': 'error',
            'error': str(exc)
        })
|
||
|
||
|
||
import secrets
|
||
|
||
|
||
@socketio.on('share_plan')
def handle_share_plan(data):
    """
    WebSocket handler: create a share link for a task.

    Socket.IO envelope:
        { id: 'share_plan-xxx', action: 'share_plan',
          data: { id: 'ws_req_xxx', data: {...} } }

    Request fields (inner ``data``):
        plan_id         -- id of the task to share (required).
        expiration_days -- lifetime in days; 0 means the share never expires
                           (default 7). Must be a non-negative number.
        extraction_code -- optional access code; stored upper-cased.
        auto_fill_code  -- when true (default) and a code is set, the returned
                           URL carries the code as a ``?code=`` query parameter.

    The task fields listed below are copied into the share record at creation
    time. The reply is sent on the ``response`` event: share_url, share_token,
    extraction_code, auto_fill_code and task_id on success, or an error message.
    """
    from urllib.parse import quote

    request_id = data.get('id')  # id of the outer socket.io envelope
    incoming_data = data.get('data', {}).get('data', {})  # the actual request body
    plan_id = incoming_data.get('plan_id')
    expiration_days = incoming_data.get('expiration_days', 7)
    extraction_code = incoming_data.get('extraction_code', '')
    auto_fill_code = incoming_data.get('auto_fill_code', True)

    if not plan_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 plan_id(task_id)'
        })
        return

    # Reject bad lifetimes up front. A negative value would otherwise create
    # a link that is already expired, and a non-number cannot become a timedelta.
    if not isinstance(expiration_days, (int, float)) or expiration_days < 0:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': f'无效的 expiration_days: {expiration_days}'
        })
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, plan_id)
            if not task:
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': f'任务不存在: {plan_id}'
                })
                return

            # Unguessable token that identifies the share in URLs.
            share_token = secrets.token_urlsafe(16)

            # 0 means "never expires"; otherwise expiry is stored as a UTC instant.
            if expiration_days == 0:
                expires_at = None
            else:
                expires_at = datetime.now(timezone.utc) + timedelta(days=expiration_days)

            # Normalise the code once. Non-string codes (e.g. a number sent by
            # the client) are stringified rather than failing on .upper().
            normalized_code = str(extraction_code).upper() if extraction_code else None

            # Snapshot of the task fields exposed through the share.
            task_data = {
                "general_goal": task.query,
                "task_outline": task.task_outline,
                "assigned_agents": task.assigned_agents,
                "agent_scores": task.agent_scores,
                "agents_info": task.agents_info,
                "branches": task.branches,
                "result": task.result,
                "rehearsal_log": task.rehearsal_log,
                "status": task.status.value if task.status else None,
            }

            PlanShareCRUD.create(
                db=db,
                share_token=share_token,
                task_id=plan_id,
                task_data=task_data,
                expires_at=expires_at,
                extraction_code=normalized_code,
            )

            share_url = f"/share/{share_token}"
            if normalized_code and auto_fill_code:
                # The code is user-chosen. URL-encode it so characters such as
                # '&', '#' or spaces cannot corrupt the query string.
                share_url += f"?code={quote(normalized_code, safe='')}"

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    "share_url": share_url,
                    "share_token": share_token,
                    "extraction_code": normalized_code,
                    "auto_fill_code": auto_fill_code,
                    "task_id": plan_id,
                }
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('import_shared_plan')
def handle_import_shared_plan(data):
    """
    WebSocket handler: import a shared task into a user's own history.

    Socket.IO envelope:
        { id: 'import_shared_plan-xxx', action: 'import_shared_plan',
          data: { id: 'ws_req_xxx', data: { share_token, user_id } } }

    A new task with a fresh UUID is created for ``user_id`` from the share's
    stored snapshot (branches and rehearsal log are copied when present). The
    share's view counter is incremented, and ``history_updated`` is broadcast.
    Errors: missing token/user, unknown token, or expired share.
    """
    request_id = data.get('id')  # id of the outer socket.io envelope
    incoming_data = data.get('data', {}).get('data', {})  # the actual request body
    share_token = incoming_data.get('share_token')
    user_id = incoming_data.get('user_id')

    if not share_token:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 share_token'
        })
        return

    if not user_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 user_id'
        })
        return

    try:
        with get_db_context() as db:
            share = PlanShareCRUD.get_by_token(db, share_token)
            if not share:
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': '分享链接无效或已失效'
                })
                return

            # share_plan writes expires_at from datetime.now(timezone.utc). Compare
            # in UTC: a naive value read back is treated as UTC instead of being
            # compared with server-local time.
            expires_at = share.expires_at
            if expires_at is not None:
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)
                if expires_at < datetime.now(timezone.utc):
                    emit('response', {
                        'id': request_id,
                        'status': 'error',
                        'error': '分享链接已过期'
                    })
                    return

            # NOTE(review): the share's extraction_code is not checked here, so
            # anyone holding the token can import even a code-protected share.
            # Enforcing it needs the client to send the code; confirm the
            # intended protocol.
            task_data = share.task_data or {}

            # The imported copy gets its own id so it is independent of the source task.
            new_task_id = str(uuid.uuid4())

            MultiAgentTaskCRUD.create(
                db=db,
                task_id=new_task_id,
                user_id=user_id,
                query=task_data.get("general_goal", ""),
                agents_info=task_data.get("agents_info", []),
                task_outline=task_data.get("task_outline"),
                assigned_agents=task_data.get("assigned_agents"),
                agent_scores=task_data.get("agent_scores"),
                result=task_data.get("result"),
            )

            if task_data.get("branches"):
                MultiAgentTaskCRUD.update_branches(db, new_task_id, task_data["branches"])

            if task_data.get("rehearsal_log"):
                MultiAgentTaskCRUD.update_rehearsal_log(db, new_task_id, task_data["rehearsal_log"])

            # An import also counts as a view of the share.
            PlanShareCRUD.increment_view_count(db, share_token)

            # Broadcast so clients refresh their history lists.
            socketio.emit('history_updated', {'user_id': user_id})

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    "message": "导入成功",
                    "task_id": new_task_id,
                }
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('save_branches')
def handle_save_branches(data):
    """
    WebSocket handler: persist the flow (task-outline) branches of a task.

    Request:
        {"id": "request-id", "action": "save_branches",
         "data": {"task_id": "task-id", "branches": [...]}}

    The task's ``branches`` column is stored as
        {"flow_branches": [...], "task_process_branches": {...}}
    Only ``flow_branches`` is replaced. Any existing ``task_process_branches``
    is carried over unchanged. If the task is not found nothing is written,
    but a success response is still sent.
    """
    request_id = data.get('id')
    body = data.get('data', {})
    task_id = body.get('task_id')
    flow_branches = body.get('branches', [])

    def reply(status, **payload):
        emit('response', {'id': request_id, 'status': status, **payload})

    if not task_id:
        reply('error', error='缺少 task_id 参数')
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, task_id)
            if task:
                # Work on a deep copy so the ORM-loaded JSON is never mutated in place.
                stored = copy.deepcopy(task.branches) if task.branches else {}

                # Keep the task-process branches that are already saved.
                kept_process = {}
                if isinstance(stored, dict):
                    kept_process = stored.get('task_process_branches', {})

                MultiAgentTaskCRUD.update_branches(db, task_id, {
                    'flow_branches': flow_branches,
                    'task_process_branches': kept_process,
                })
                print(f"[save_branches] 已保存分支数据到数据库,task_id={task_id}, flow_branches_count={len(flow_branches)}, task_process_count={len(kept_process)}")

            reply('success', data={
                "message": "分支数据保存成功",
                "branches_count": len(flow_branches),
            })
    except Exception as exc:
        reply('error', error=str(exc))
|
||
|
||
|
||
def _merge_task_process_branches(existing, incoming):
    """Merge incoming task-process branches into the existing ones.

    Both arguments map ``stepId -> {agentGroupKey -> [branch, ...]}``.
    - Steps and agent-group keys that do not exist yet are added as-is.
    - For an agent-group key present on both sides, incoming branches are
      appended unless a branch with the same ``id`` is already present.
      Existing entries are never replaced.
    - Branches without an ``id`` are always appended.

    Nested containers of *existing* are mutated; pass a copy. Returns a new
    top-level dict.
    """
    merged = dict(existing)
    for step_id, incoming_groups in incoming.items():
        if step_id not in merged:
            merged[step_id] = incoming_groups
            continue

        step_groups = merged[step_id]
        for agent_key, new_branches in incoming_groups.items():
            if agent_key not in step_groups:
                step_groups[agent_key] = new_branches
                continue

            target = step_groups[agent_key]
            seen_ids = {b.get('id') for b in target if b.get('id')}
            for branch in new_branches:
                branch_id = branch.get('id')
                if branch_id not in seen_ids:
                    target.append(branch)
                    # Also remember ids appended in this batch, so an id repeated
                    # inside a single request is not stored twice.
                    if branch_id:
                        seen_ids.add(branch_id)
    return merged


@socketio.on('save_task_process_branches')
def handle_save_task_process_branches(data):
    """
    WebSocket handler: save task-process branches, merging them with stored ones.

    Request:
        {"id": "request-id", "action": "save_task_process_branches",
         "data": {
             "task_id": "task-id",          // parent task id (DB primary key)
             "branches": {
                 "stepId-1": {
                     '["AgentA","AgentB"]': [{...}, {...}],
                     '["AgentC"]': [...]
                 },
                 "stepId-2": {...}
             }
         }}

    Incoming branches are merged into the stored ``task_process_branches``
    (see ``_merge_task_process_branches``). ``flow_branches`` is kept as stored.
    If the task is not found, a warning is printed and success is still sent.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get('task_id')
    branches = incoming_data.get('branches', {})

    if not task_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 task_id 参数'
        })
        return

    try:
        with get_db_context() as db:
            existing_task = MultiAgentTaskCRUD.get_by_id(db, task_id)

            if existing_task:
                # Deep copy so nested lists/dicts of the ORM-loaded value are not shared.
                existing_branches = copy.deepcopy(existing_task.branches) if existing_task.branches else {}

                if isinstance(existing_branches, dict):
                    # Keep any other top-level keys that are already stored.
                    new_branches = dict(existing_branches)
                    existing_flow_branches = existing_branches.get('flow_branches', [])
                    existing_task_process = existing_branches.get('task_process_branches') or {}
                else:
                    # Unexpected stored shape: rebuild the dict from scratch.
                    new_branches = {}
                    existing_flow_branches = []
                    existing_task_process = {}

                merged_task_process = _merge_task_process_branches(existing_task_process, branches)
                new_branches['task_process_branches'] = merged_task_process
                new_branches['flow_branches'] = existing_flow_branches

                # Assign a new object so SQLAlchemy sees the JSON column as changed.
                existing_task.branches = new_branches
                db.flush()
                db.commit()

                print(f"[save_task_process_branches] 保存完成,stepIds: {list(merged_task_process.keys())}")
            else:
                print(f"[save_task_process_branches] 警告: 找不到任务,task_id={task_id}")

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    "message": "任务过程分支数据保存成功",
                    "step_count": len(branches)
                }
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('delete_task_process_branch')
def handle_delete_task_process_branch(data):
    """
    WebSocket handler: delete one task-process branch.

    Request:
        {"id": "request-id", "action": "delete_task_process_branch",
         "data": {"task_id": "task-id",       // parent task id (DB primary key)
                  "stepId": "step-id",        // step id
                  "branchId": "branch-id"}}   // branch to delete

    Stored shape of the task's ``branches`` column:
        {"flow_branches": [...],
         "task_process_branches": {
             "stepId-1": {'["AgentA","AgentB"]': [{...branch...}]}}}

    The first agent group under ``stepId`` that contains ``branchId`` has
    every branch with that id removed. An emptied group is dropped, and the
    step is dropped when it has no groups left. Replies with an error if
    nothing matched.
    """
    request_id = data.get('id')
    body = data.get('data', {})
    task_id = body.get('task_id')
    step_id = body.get('stepId')
    branch_id = body.get('branchId')

    def reply(status, **payload):
        emit('response', {'id': request_id, 'status': status, **payload})

    if not (task_id and step_id and branch_id):
        reply('error', error='缺少必要参数:task_id, stepId, branchId')
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, task_id)

            # Deep copy so the ORM-loaded JSON is never mutated in place.
            snapshot = None
            if task:
                snapshot = copy.deepcopy(task.branches) if task.branches else {}

            process_map = snapshot.get('task_process_branches', {}) if isinstance(snapshot, dict) else {}

            if step_id in process_map:
                groups = process_map[step_id]
                for group_key, group_branches in groups.items():
                    kept = [b for b in group_branches if b.get('id') != branch_id]
                    if len(kept) == len(group_branches):
                        continue  # the branch is not in this agent group

                    if kept:
                        groups[group_key] = kept
                    else:
                        del groups[group_key]
                    if not groups:
                        del process_map[step_id]

                    snapshot['task_process_branches'] = process_map
                    task.branches = snapshot
                    db.flush()
                    db.commit()

                    print(f"[delete_task_process_branch] 删除成功,task_id={task_id}, step_id={step_id}, branch_id={branch_id}")
                    reply('success', data={
                        "message": "分支删除成功",
                        "deleted_branch_id": branch_id,
                    })
                    # Returning right away also avoids iterating a dict whose size just changed.
                    return

            print(f"[delete_task_process_branch] 警告: 找不到要删除的分支,task_id={task_id}, step_id={step_id}, branch_id={branch_id}")
            reply('error', error='未找到要删除的分支')
    except Exception as exc:
        reply('error', error=str(exc))
|
||
|
||
|
||
def _remove_task_process_node(branches, step_id, branch_id, node_id, edges):
    """Remove one node from a task-process branch, in place.

    *branches* is the task's ``branches`` column value, expected shape
    ``{"task_process_branches": {step_id: {agent_key: [branch, ...]}}}``.
    All agent groups under *step_id* are searched for the branch whose ``id``
    is *branch_id*. If that branch has a node with ``id`` *node_id*, the node
    is removed, the ``tasks`` entry at the same index is removed (when
    present), the branch's ``edges`` is replaced with *edges*, and True is
    returned. Otherwise *branches* is left unchanged and False is returned.
    """
    if not isinstance(branches, dict):
        return False
    process = branches.get('task_process_branches')
    if not isinstance(process, dict):
        return False
    step_branches = process.get(step_id)
    if not isinstance(step_branches, dict):
        return False

    for branches_list in step_branches.values():
        if not isinstance(branches_list, list):
            continue
        for branch in branches_list:
            if branch.get('id') != branch_id:
                continue

            nodes = branch.get('nodes', [])
            tasks = branch.get('tasks', [])
            index = next((i for i, node in enumerate(nodes) if node.get('id') == node_id), None)
            if index is None:
                return False  # branch found, but it has no such node

            nodes.pop(index)
            # Assumes tasks[i] corresponds to nodes[i] (index-aligned lists).
            # TODO confirm against the frontend data model.
            if index < len(tasks):
                tasks.pop(index)

            branch['nodes'] = nodes
            branch['tasks'] = tasks
            branch['edges'] = edges  # the client sends the already-updated edges
            return True
    return False


@socketio.on('delete_task_process_node')
def handle_delete_task_process_node(data):
    """
    WebSocket handler: delete a single node from a task-process branch.

    Request:
        {"id": "request-id", "action": "delete_task_process_node",
         "data": {"task_id": "task-id",       // parent task id (DB primary key)
                  "stepId": "step-id",        // step id
                  "branchId": "branch-id",    // branch containing the node
                  "nodeId": "node-id",        // node to delete
                  "edges": [...]}}            // branch edges after deletion

    The change is committed and success is reported only when the node was
    actually found and removed. Otherwise an error is sent and nothing is
    written.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get('task_id')
    step_id = incoming_data.get('stepId')
    branch_id = incoming_data.get('branchId')
    node_id = incoming_data.get('nodeId')
    edges = incoming_data.get('edges', [])

    if not task_id or not step_id or not branch_id or not node_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少必要参数:task_id, stepId, branchId, nodeId'
        })
        return

    try:
        with get_db_context() as db:
            existing_task = MultiAgentTaskCRUD.get_by_id(db, task_id)

            if existing_task:
                # Deep copy so the ORM-loaded JSON is not mutated in place.
                existing_branches = copy.deepcopy(existing_task.branches) if existing_task.branches else {}

                if _remove_task_process_node(existing_branches, step_id, branch_id, node_id, edges):
                    # Assign a new object so SQLAlchemy sees the JSON column as changed.
                    existing_task.branches = existing_branches
                    db.flush()
                    db.commit()

                    print(f"[delete_task_process_node] 删除成功,task_id={task_id}, step_id={step_id}, branch_id={branch_id}, node_id={node_id}")

                    emit('response', {
                        'id': request_id,
                        'status': 'success',
                        'data': {
                            "message": "节点删除成功",
                            "deleted_node_id": node_id
                        }
                    })
                    return

            print(f"[delete_task_process_node] 警告: 找不到要删除的节点,task_id={task_id}, step_id={step_id}, branch_id={branch_id}, node_id={node_id}")
            emit('response', {
                'id': request_id,
                'status': 'error',
                'error': '未找到要删除的节点'
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('save_task_outline')
def handle_save_task_outline(data):
    """
    WebSocket handler: overwrite the stored task outline.

    Request:
        {"id": "request-id", "action": "save_task_outline",
         "data": {"task_id": "task-id", "task_outline": {...}}}

    ``task_outline`` is passed to ``MultiAgentTaskCRUD.update_task_outline``
    as received (it may be None if omitted).
    """
    request_id = data.get('id')
    payload = data.get('data', {})
    task_id = payload.get('task_id')
    outline = payload.get('task_outline')

    def reply(status, **body):
        emit('response', {'id': request_id, 'status': status, **body})

    if not task_id:
        reply('error', error='缺少 task_id 参数')
        return

    try:
        with get_db_context() as db:
            MultiAgentTaskCRUD.update_task_outline(db, task_id, outline)
            print(f"[save_task_outline] 已保存任务大纲到数据库,task_id={task_id}")
            reply('success', data={"message": "任务大纲保存成功"})
    except Exception as exc:
        reply('error', error=str(exc))
|
||
|
||
|
||
def _stale_combination_keys(combinations, confirmed_groups):
    """Return the keys of *combinations* that match no confirmed agent group.

    Keys are JSON-array strings naming an agent group. A key counts as
    confirmed if it equals ``json.dumps(list(group), sort_keys=True)``
    (Python's default ``", "`` separators), or if it parses as a JSON list
    equal to a group. The second check also accepts keys written with other
    serializers, e.g. compact ``'["A","B"]'``. The comparison is
    order-sensitive, as before.
    """
    confirmed_lists = [list(group) for group in confirmed_groups]
    exact_keys = {json.dumps(group, sort_keys=True) for group in confirmed_lists}

    stale = []
    for key in combinations:
        if key in exact_keys:
            continue
        try:
            parsed = json.loads(key)
        except (TypeError, ValueError):
            parsed = None
        if not (isinstance(parsed, list) and parsed in confirmed_lists):
            stale.append(key)
    return stale


@socketio.on('update_assigned_agents')
def handle_update_assigned_agents(data):
    """
    WebSocket handler: update ``assigned_agents`` for one step of a task.

    Request:
        {"id": "request-id", "action": "update_assigned_agents",
         "data": {
             "task_id": "task-id",               // parent task id (DB primary key)
             "step_id": "step-id",               // step-level id
             "agents": ["AgentA", "AgentB"],     // currently selected agents
             "confirmed_groups": [["AgentA"], ["AgentA", "AgentB"]],  // optional
             "agent_combinations": {...},        // optional: per-group TaskProcess data
             "save_combination": true            // optional; handled by fill_step_task_process
         }}

    Per step, the stored value is
    ``{"current": [...], "confirmed_groups": [...], "agent_combinations": {...}}``.
    Non-empty ``agents`` replaces ``current``. Non-empty ``confirmed_groups``
    replaces the group list and drops ``agent_combinations`` entries for
    groups that are no longer confirmed. ``agent_combinations`` entries are
    merged in by key.
    """
    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get('task_id')
    step_id = incoming_data.get('step_id')
    agents = incoming_data.get('agents', [])
    confirmed_groups = incoming_data.get('confirmed_groups', [])
    agent_combinations = incoming_data.get('agent_combinations', {})

    if not task_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 task_id 参数'
        })
        return

    if not step_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 step_id 参数'
        })
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, task_id)
            if not task:
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': f'任务不存在: {task_id}'
                })
                return

            # Deep copy so the ORM-loaded JSON value is not mutated in place;
            # the other branch handlers in this file do the same.
            raw_assigned = task.assigned_agents
            existing_assigned = copy.deepcopy(raw_assigned) if isinstance(raw_assigned, dict) else {}

            step_entry = existing_assigned.get(step_id)
            if not isinstance(step_entry, dict):
                # Missing entry, or a shape that cannot hold the sub-keys below.
                # A bare list is assumed to be a legacy agent selection and is
                # kept as "current" (TODO confirm the legacy format).
                step_entry = {"current": list(step_entry)} if isinstance(step_entry, list) else {}
                existing_assigned[step_id] = step_entry

            step_entry.setdefault("current", [])
            step_entry.setdefault("confirmed_groups", [])
            if not isinstance(step_entry.get("agent_combinations"), dict):
                step_entry["agent_combinations"] = {}

            if agents:
                step_entry["current"] = agents

            if confirmed_groups:
                step_entry["confirmed_groups"] = confirmed_groups
                # Drop TaskProcess data cached for groups that are no longer confirmed.
                combinations = step_entry["agent_combinations"]
                for key in _stale_combination_keys(combinations, confirmed_groups):
                    del combinations[key]

            if agent_combinations:
                step_entry["agent_combinations"].update(agent_combinations)

            # Written with raw SQL rather than through the ORM attribute.
            db.execute(
                text("UPDATE multi_agent_tasks SET assigned_agents = :assigned WHERE task_id = :id"),
                {"assigned": json.dumps(existing_assigned), "id": task_id}
            )
            db.commit()
            print(f"[update_assigned_agents] 已保存: task_id={task_id}, step_id={step_id}")

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    "message": "assigned_agents 更新成功",
                    "task_id": task_id,
                    "step_id": step_id,
                    "agents": agents,
                    "confirmed_groups": confirmed_groups
                }
            })

    except Exception as e:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
# ==================== Export ====================

# Export type -> file extension written to disk and MIME type used when serving it.
EXPORT_TYPE_CONFIG = dict(
    doc=dict(ext=".docx", mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"),
    markdown=dict(ext=".md", mime="text/markdown"),
    excel=dict(ext=".xlsx", mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"),
    ppt=dict(ext=".pptx", mime="application/vnd.openxmlformats-officedocument.presentationml.presentation"),
    # Mind maps are exported as Markdown.
    mindmap=dict(ext=".md", mime="text/markdown"),
    # Infographics are exported as HTML for now.
    infographic=dict(ext=".html", mime="text/html"),
)
|
||
|
||
|
||
def ensure_export_dir(task_id: str) -> str:
    """Create (if needed) and return the per-task export directory.

    The directory is ``os.path.join(EXPORT_DIR, task_id)``. ``task_id``
    reaches here from client requests, so it must resolve to a location
    strictly inside ``EXPORT_DIR``. Values that would escape it (``..``
    segments, absolute paths) or resolve to ``EXPORT_DIR`` itself are
    rejected.

    Args:
        task_id: Task identifier used as the sub-directory name.

    Returns:
        The export directory path, as built by ``os.path.join`` (not resolved).

    Raises:
        ValueError: If ``task_id`` does not stay inside ``EXPORT_DIR``.
    """
    base_dir = os.path.realpath(EXPORT_DIR)
    task_dir = os.path.join(EXPORT_DIR, task_id)
    resolved = os.path.realpath(task_dir)
    if resolved == base_dir or os.path.commonpath([base_dir, resolved]) != base_dir:
        raise ValueError(f"非法的 task_id: {task_id!r}")
    os.makedirs(task_dir, exist_ok=True)
    return task_dir
|
||
|
||
|
||
def generate_export_file_name(task_name: str, export_type: str, max_name_length: int = 50) -> str:
    """Build an export file name (without extension) for a task.

    From *task_name*, only characters for which ``str.isalnum`` is true
    (CJK text included), spaces, ``-`` and ``_`` are kept. The result is
    stripped and cut to *max_name_length* characters, because task names are
    often whole user prompts and would otherwise exceed typical 255-byte
    file-name limits. If nothing remains, ``"export"`` is used.

    Args:
        task_name: Human-readable task name; may be empty.
        export_type: Export type label embedded in the name.
        max_name_length: Maximum number of characters kept from *task_name*
            (should be positive).

    Returns:
        ``"{name}_{export_type}_{YYYYmmdd_HHMMSS}"``, timestamped in local time.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_name = "".join(
        c for c in (task_name or "") if c.isalnum() or c in (' ', '-', '_')
    ).strip()
    # Truncation can leave a trailing space; strip again before falling back.
    safe_name = safe_name[:max_name_length].rstrip() or "export"
    return f"{safe_name}_{export_type}_{timestamp}"
|
||
|
||
|
||
@socketio.on('export')
def handle_export(data):
    """
    WebSocket handler: export a task to a file and record the export.

    Request:
        {"id": "request-id", "action": "export",
         "data": {"task_id": "task-id",     // task to export
                  "export_type": "doc",     // doc/markdown/excel/ppt/mindmap/infographic
                  "user_id": "user-id"}}    // optional, defaults to 'anonymous'

    The file is written under the task's export directory by
    ``ExportFactory.export``. If the exporter returns a falsy value or raises,
    any partial file is removed, an error is sent, and no export record is
    created. On success an ``ExportRecord`` is stored and its id, name, URL,
    size and type are returned.
    """
    import traceback

    request_id = data.get('id')
    incoming_data = data.get('data', {})
    task_id = incoming_data.get('task_id')
    export_type = incoming_data.get('export_type')
    user_id = incoming_data.get('user_id', 'anonymous')

    if not task_id:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': '缺少 task_id 参数'
        })
        return

    # The isinstance check also keeps unhashable values from crashing the lookup.
    if not isinstance(export_type, str) or export_type not in EXPORT_TYPE_CONFIG:
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': f'无效的导出类型: {export_type}'
        })
        return

    try:
        with get_db_context() as db:
            task = MultiAgentTaskCRUD.get_by_id(db, task_id)
            if not task:
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': f'任务不存在: {task_id}'
                })
                return

            export_data = {
                'task_name': task.query or '未命名任务',
                'task_content': task.query or '',
                'task_outline': task.task_outline,
                'result': task.result,
                'agents_info': task.agents_info,
                'assigned_agents': task.assigned_agents,
                'rehearsal_log': task.rehearsal_log,
                'agent_scores': task.agent_scores,
                'user_id': user_id,
                'date': datetime.now().strftime('%Y年%m月%d日'),
            }

            config = EXPORT_TYPE_CONFIG[export_type]
            file_name = generate_export_file_name(export_data['task_name'], export_type) + config['ext']
            file_path = os.path.join(ensure_export_dir(task_id), file_name)

            try:
                exported = ExportFactory.export(export_type, export_data, file_path)
            except Exception as export_error:
                print(f"导出文件失败: {export_error}")
                traceback.print_exc()
                exported = False

            if not exported:
                # Do not leave a partial or empty file behind, and do not record
                # or report a failed export as a success.
                try:
                    os.remove(file_path)
                except FileNotFoundError:
                    pass
                emit('response', {
                    'id': request_id,
                    'status': 'error',
                    'error': f'导出类型 {export_type} 不支持或生成失败'
                })
                return

            file_size = os.path.getsize(file_path)

            # Public URL for static access; assumes EXPORT_DIR is served as
            # /uploads/exports (TODO confirm against the static-file config).
            relative_path = os.path.join('uploads', 'exports', task_id, file_name)
            file_url = f"/{relative_path.replace(os.sep, '/')}"

            record = ExportRecordCRUD.create(
                db=db,
                task_id=task_id,
                user_id=user_id,
                export_type=export_type,
                file_name=file_name,
                file_path=file_path,
                file_url=file_url,
                file_size=file_size,
            )

            emit('response', {
                'id': request_id,
                'status': 'success',
                'data': {
                    'record_id': record.id,
                    'file_name': file_name,
                    'file_url': file_url,
                    'file_size': file_size,
                    'export_type': export_type,
                }
            })

    except Exception as e:
        traceback.print_exc()
        emit('response', {
            'id': request_id,
            'status': 'error',
            'error': str(e)
        })
|
||
|
||
|
||
@socketio.on('get_export_list')
def handle_get_export_list(data):
    """
    WebSocket handler: list the export records of a task.

    Request:
        {"id": "request-id", "action": "get_export_list",
         "data": {"task_id": "task-id"}}

    Success payload: ``{"list": [record.to_dict(), ...], "total": n}``.
    """
    request_id = data.get('id')
    task_id = data.get('data', {}).get('task_id')

    def reply(status, **body):
        emit('response', {'id': request_id, 'status': status, **body})

    if not task_id:
        reply('error', error='缺少 task_id 参数')
        return

    try:
        with get_db_context() as db:
            exports = []
            for rec in ExportRecordCRUD.get_by_task_id(db, task_id):
                exports.append(rec.to_dict())
            reply('success', data={'list': exports, 'total': len(exports)})
    except Exception as exc:
        reply('error', error=str(exc))
|
||
|
||
|
||
# ==================== REST API endpoints ====================

@app.route('/api/export/<int:record_id>/download', methods=['GET'])
def download_export(record_id: int):
    """Send an export file as an attachment.

    The MIME type comes from ``EXPORT_TYPE_CONFIG`` (falling back to
    ``application/octet-stream``). Returns 404 JSON when the record or its
    file is missing, and 500 JSON with the exception text on other errors.
    """
    try:
        with get_db_context() as db:
            record = ExportRecordCRUD.get_by_id(db, record_id)

            problem = None
            if not record:
                problem = '导出记录不存在'
            elif not os.path.exists(record.file_path):
                problem = '文件不存在'
            if problem:
                return jsonify({'error': problem}), 404

            mime = EXPORT_TYPE_CONFIG.get(record.export_type, {}).get('mime', 'application/octet-stream')
            return send_file(
                record.file_path,
                mimetype=mime,
                as_attachment=True,
                download_name=record.file_name,
            )
    except Exception as exc:
        return jsonify({'error': str(exc)}), 500
|
||
|
||
|
||
@app.route('/api/export/<int:record_id>/preview', methods=['GET'])
def preview_export(record_id: int):
    """Preview an export file.

    - markdown / mindmap: JSON ``{"content": <utf-8 text>, "type": <export_type>}``.
    - Word / Excel / PowerPoint: the file is streamed inline (not as an
      attachment) with the matching OOXML MIME type.
    - anything else: JSON with ``file_url``, ``file_name`` and ``type`` so the
      frontend can handle it.

    Returns 404 JSON when the record or its file is missing, and 500 JSON
    with the exception text on other errors.
    """
    docx_mime = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
    xlsx_mime = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    pptx_mime = 'application/vnd.openxmlformats-officedocument.presentationml.presentation'
    # Binary types streamed inline. Besides the EXPORT_TYPE_CONFIG keys,
    # extension-style aliases are accepted as well.
    inline_mimes = {
        'doc': docx_mime, 'docx': docx_mime,
        'excel': xlsx_mime, 'xlsx': xlsx_mime, 'xls': xlsx_mime,
        'ppt': pptx_mime, 'pptx': pptx_mime,
    }

    try:
        with get_db_context() as db:
            record = ExportRecordCRUD.get_by_id(db, record_id)
            if not record:
                return jsonify({'error': '导出记录不存在'}), 404

            if not os.path.exists(record.file_path):
                return jsonify({'error': '文件不存在'}), 404

            export_type = record.export_type

            # Text formats (mind maps are stored as Markdown) are returned as content.
            if export_type in ('markdown', 'mindmap'):
                with open(record.file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                return jsonify({'content': content, 'type': export_type})

            if export_type in inline_mimes:
                return send_file(
                    record.file_path,
                    mimetype=inline_mimes[export_type],
                    as_attachment=False,
                    download_name=record.file_name
                )

            return jsonify({
                'file_url': record.file_url,
                'file_name': record.file_name,
                'type': export_type
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/export/<int:record_id>/share', methods=['GET'])
def share_export(record_id: int):
    """Return a share link for an export record.

    The token is ``export_{record_id}_{created_at as unix seconds}``. It is
    derived from the record rather than random, and nothing is persisted.
    A record without ``created_at`` uses 0 for the timestamp instead of
    failing, matching ``get_share_info``'s handling of a missing timestamp.
    ``expired_at`` is always None (no expiry is implemented).

    NOTE(review): the ``/share/<share_token>`` route in this file resolves
    tokens through ``PlanShareCRUD``, so an ``export_...`` token will 404
    there unless the frontend routes it elsewhere. Confirm.
    """
    try:
        with get_db_context() as db:
            record = ExportRecordCRUD.get_by_id(db, record_id)
            if not record:
                return jsonify({'error': '导出记录不存在'}), 404

            created_ts = int(record.created_at.timestamp()) if record.created_at else 0
            share_token = f"export_{record.id}_{created_ts}"
            share_url = f"/share/{share_token}"

            return jsonify({
                'share_url': share_url,
                'file_name': record.file_name,
                'expired_at': None  # TODO: add an expiry time
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/export/<int:record_id>/share/info', methods=['GET'])
def get_share_info(record_id: int):
    """Public (no login) metadata for a shared export file.

    Returns the file name, export type, creation time (ISO 8601, or None) and
    size (0 when unknown). Returns 404 JSON if the record does not exist, and
    500 JSON with the exception text on other errors.
    """
    try:
        with get_db_context() as db:
            record = ExportRecordCRUD.get_by_id(db, record_id)
            if not record:
                return jsonify({'error': '文件不存在或已失效'}), 404

            created = record.created_at
            info = {
                'file_name': record.file_name,
                'export_type': record.export_type,
                'created_at': created.isoformat() if created else None,
                'file_size': record.file_size or 0,
            }
            return jsonify(info)
    except Exception as exc:
        return jsonify({'error': str(exc)}), 500
|
||
|
||
|
||
# ==================== Task share page ====================

@app.route('/share/<share_token>', methods=['GET'])
def get_shared_plan_page(share_token: str):
    """Return a shared task snapshot (no login required).

    - 404 if the token is unknown or the share has expired. Expiry is compared
      in UTC; a naive stored value is treated as UTC.
    - If the share has an extraction code, the ``code`` query parameter
      (case-insensitive) must match it, otherwise 403. This is the same rule
      as ``/api/share/<share_token>``, so this route cannot be used to bypass
      the code.
    - Each successful view increments the share's view counter.
    """
    # Extraction codes are stored upper-cased.
    code = request.args.get('code', '').upper()

    try:
        with get_db_context() as db:
            share = PlanShareCRUD.get_by_token(db, share_token)
            if not share:
                return jsonify({'error': '分享链接无效或已失效'}), 404

            expires_at = share.expires_at
            if expires_at is not None:
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)
                if expires_at < datetime.now(timezone.utc):
                    return jsonify({'error': '分享链接已过期'}), 404

            if share.extraction_code:
                if not code:
                    return jsonify({'error': '请输入提取码'}), 403
                # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
                if not secrets.compare_digest(code.encode('utf-8'), share.extraction_code.encode('utf-8')):
                    return jsonify({'error': '提取码错误'}), 403

            PlanShareCRUD.increment_view_count(db, share_token)

            task_data = share.task_data
            return jsonify({
                'share_token': share_token,
                'task_id': share.task_id,
                'task_data': task_data,
                'created_at': share.created_at.isoformat() if share.created_at else None,
                'view_count': share.view_count,
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/share/<share_token>/check', methods=['GET'])
def check_share_code(share_token: str):
    """Report whether a shared plan requires an extraction code.

    Returns:
        200 with ``need_code`` and ``has_extraction_code``. Both are True
        exactly when the share has a non-empty extraction code.
        404 when the token is unknown or the share has expired.
        500 with the exception text on any other failure.
    """
    try:
        with get_db_context() as db:
            share = PlanShareCRUD.get_by_token(db, share_token)
            if not share:
                return jsonify({'error': '分享链接无效或已失效'}), 404

            # Compare an aware timestamp with aware "now" and a naive one
            # with naive local "now". Stripping tzinfo from an aware value
            # and comparing it to local time shifts the deadline by the gap
            # between its UTC offset and the server's local offset.
            expires_at = share.expires_at
            if expires_at is not None:
                is_aware = expires_at.utcoffset() is not None
                now = datetime.now(timezone.utc) if is_aware else datetime.now()
                if expires_at < now:
                    return jsonify({'error': '分享链接已过期'}), 404

            # Both response keys report the same flag.
            need_code = bool(share.extraction_code)
            return jsonify({
                'need_code': need_code,
                'has_extraction_code': need_code,
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/share/<share_token>', methods=['GET'])
def get_shared_plan_info(share_token: str):
    """Return a shared plan's details (API endpoint, no login required).

    Query parameters:
        code: extraction code. It is required when the share has one, and
            matching ignores case and surrounding whitespace.

    Returns:
        200 with ``share_token``, ``task_id``, ``task_data``, ``created_at``,
        ``expires_at`` (ISO 8601 or None), ``view_count`` and
        ``extraction_code``.
        403 when a required code is missing or wrong.
        404 when the token is unknown or the share has expired.
        500 with the exception text on any other failure.
    """
    # Normalise the user-supplied code. The stored code is upper-cased at
    # comparison time as well, so a code saved in lower or mixed case still
    # matches; previously only the input side was upper-cased.
    code = request.args.get('code', '').strip().upper()

    try:
        with get_db_context() as db:
            share = PlanShareCRUD.get_by_token(db, share_token)
            if not share:
                return jsonify({'error': '分享链接无效或已失效'}), 404

            # Compare an aware timestamp with aware "now" and a naive one
            # with naive local "now". Stripping tzinfo from an aware value
            # and comparing it to local time shifts the deadline by the gap
            # between its UTC offset and the server's local offset.
            expires_at = share.expires_at
            if expires_at is not None:
                is_aware = expires_at.utcoffset() is not None
                now = datetime.now(timezone.utc) if is_aware else datetime.now()
                if expires_at < now:
                    return jsonify({'error': '分享链接已过期'}), 404

            extraction_code = share.extraction_code
            if extraction_code:
                if not code:
                    return jsonify({'error': '请输入提取码'}), 403
                if code != str(extraction_code).upper():
                    return jsonify({'error': '提取码错误'}), 403

            # Only count views that actually receive the data.
            PlanShareCRUD.increment_view_count(db, share_token)

            return jsonify({
                'share_token': share_token,
                'task_id': share.task_id,
                'task_data': share.task_data,
                'created_at': share.created_at.isoformat() if share.created_at else None,
                'expires_at': share.expires_at.isoformat() if share.expires_at else None,
                'view_count': share.view_count,
                'extraction_code': share.extraction_code,
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/share/import', methods=['POST'])
def import_shared_plan():
    """Copy a shared plan into the caller's own task history (HTTP, no WebSocket).

    Expects a JSON object body with ``share_token`` and ``user_id``. The plan
    is copied under a freshly generated task id. Branches and rehearsal logs
    are copied too when the share snapshot contains them.

    Returns:
        200 with ``success``, ``message`` and the new ``task_id``.
        400 when ``share_token`` is missing, including when the body is not
        a JSON object.
        401 when ``user_id`` is missing.
        404 when the token is unknown or the share has expired.
        500 with the exception text on any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing or malformed
        # body, so such requests get the 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        share_token = data.get('share_token')
        # NOTE(review): user_id comes straight from the request body, so any
        # caller can import into any account. Confirm whether an auth layer
        # elsewhere validates it.
        user_id = data.get('user_id')

        if not share_token:
            return jsonify({'error': '缺少 share_token'}), 400

        if not user_id:
            return jsonify({'error': '缺少 user_id,请先登录'}), 401

        with get_db_context() as db:
            share = PlanShareCRUD.get_by_token(db, share_token)
            if not share:
                return jsonify({'error': '分享链接无效或已失效'}), 404

            # Compare an aware timestamp with aware "now" and a naive one
            # with naive local "now". Stripping tzinfo from an aware value
            # and comparing it to local time shifts the deadline by the gap
            # between its UTC offset and the server's local offset.
            expires_at = share.expires_at
            if expires_at is not None:
                is_aware = expires_at.utcoffset() is not None
                now = datetime.now(timezone.utc) if is_aware else datetime.now()
                if expires_at < now:
                    return jsonify({'error': '分享链接已过期'}), 404

            # NOTE(review): unlike /api/share/<share_token>, this endpoint does
            # not verify share.extraction_code. A code-protected plan can
            # therefore be imported with the token alone. Enforcing it requires
            # clients to send the code as well; confirm before tightening.

            # A NULL snapshot would otherwise crash on .get() with a 500.
            task_data = share.task_data or {}

            # A new id is generated because the copy belongs to the importer.
            new_task_id = str(uuid.uuid4())

            MultiAgentTaskCRUD.create(
                db=db,
                task_id=new_task_id,
                user_id=user_id,
                query=task_data.get("general_goal", ""),
                agents_info=task_data.get("agents_info", []),
                task_outline=task_data.get("task_outline"),
                assigned_agents=task_data.get("assigned_agents"),
                agent_scores=task_data.get("agent_scores"),
                result=task_data.get("result"),
            )

            branches = task_data.get("branches")
            if branches:
                MultiAgentTaskCRUD.update_branches(db, new_task_id, branches)

            rehearsal_log = task_data.get("rehearsal_log")
            if rehearsal_log:
                MultiAgentTaskCRUD.update_rehearsal_log(db, new_task_id, rehearsal_log)

            # An import counts as a view of the share.
            PlanShareCRUD.increment_view_count(db, share_token)

            return jsonify({
                'success': True,
                'message': '导入成功',
                'task_id': new_task_id,
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/export/<int:record_id>', methods=['DELETE'])
def delete_export(record_id: int):
    """Delete an export record together with its file on disk.

    A record whose file path is empty, or whose file is already gone, is
    still deleted from the database.

    Returns:
        200 with ``success`` and ``message`` once the record is deleted.
        404 when no record has this id.
        500 with the exception text on any other failure, e.g. a file that
        exists but cannot be removed.
    """
    # NOTE(review): no ownership/auth check is visible here; any caller that
    # knows a record_id can delete it.
    try:
        with get_db_context() as db:
            record = ExportRecordCRUD.get_by_id(db, record_id)
            if not record:
                return jsonify({'error': '导出记录不存在'}), 404

            file_path = record.file_path
            if file_path:
                # EAFP instead of exists()+remove(): a file that is already
                # missing, or vanishes between the check and the unlink, must
                # not block deleting the record.
                try:
                    os.remove(file_path)
                except FileNotFoundError:
                    pass

            ExportRecordCRUD.delete(db, record_id)

            return jsonify({'success': True, 'message': '删除成功'})

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the previous hard-coded
    # behaviour: bind all interfaces, port 8000, debug mode on.
    parser = argparse.ArgumentParser(
        description="start the backend for AgentCoord"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="set the port number, 8000 by default.",
    )
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="set the interface to bind, 0.0.0.0 (all interfaces) by default.",
    )
    parser.add_argument(
        "--debug",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="run in Flask debug mode (reloader + interactive debugger); "
        "pass --no-debug when the port is reachable from other hosts.",
    )
    args = parser.parse_args()
    init()
    # Use socketio.run instead of app.run so WebSocket traffic is served too.
    socketio.run(
        app,
        host=args.host,
        port=args.port,
        debug=args.debug,
        allow_unsafe_werkzeug=True,
    )