Files
AgentCoord/backend/db/crud.py
2026-03-12 13:35:04 +08:00

578 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Database CRUD operations.

Wraps all database access methods (based on DATABASE_DESIGN.md).
"""
import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from .models import MultiAgentTask, UserAgent, ExportRecord, PlanShare
class MultiAgentTaskCRUD:
    """CRUD operations for multi-agent task records (``MultiAgentTask``).

    Every mutating method commits the session immediately and refreshes the
    instance before returning it.

    NOTE(review): several columns (``agents_info``, ``result``,
    ``rehearsal_log``, ``branches``, ...) are presumably JSON columns. Unless
    the model wraps them with ``sqlalchemy.ext.mutable``, the ORM does not
    detect in-place mutation of their values. The methods below therefore
    always assign a new object rather than mutating the stored one. Confirm
    against models.py.
    """

    @staticmethod
    def _save(db: Session, task: MultiAgentTask) -> MultiAgentTask:
        """Commit pending changes and reload *task* from the database."""
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def create(
        db: Session,
        task_id: Optional[str] = None,
        user_id: str = "",
        query: str = "",
        agents_info: Optional[list] = None,
        task_outline: Optional[dict] = None,
        assigned_agents: Optional[list] = None,
        agent_scores: Optional[dict] = None,
        result: Optional[str] = None,
    ) -> MultiAgentTask:
        """Create and persist a task record.

        Args:
            db: Active session; it is committed before returning.
            task_id: Explicit task ID. A random UUID4 string is generated
                when it is omitted or falsy (e.g. ``""``).
            user_id: Owner of the task.
            query: The task's query text.
            agents_info: Agent information; defaults to a new empty list.
            task_outline: Initial task outline, if any.
            assigned_agents: Initial agent assignment, if any.
            agent_scores: Initial agent scores, if any.
            result: Initial result, if any.

        Returns:
            The persisted and refreshed ``MultiAgentTask``.
        """
        task = MultiAgentTask(
            task_id=task_id or str(uuid.uuid4()),
            user_id=user_id,
            query=query,
            # Fresh list per call. The former ``= []`` default was one list
            # object shared by every call that omitted this argument.
            agents_info=agents_info if agents_info is not None else [],
            task_outline=task_outline,
            assigned_agents=assigned_agents,
            agent_scores=agent_scores,
            result=result,
        )
        db.add(task)
        return MultiAgentTaskCRUD._save(db, task)

    @staticmethod
    def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Return the task with ``task_id``, or ``None`` if it does not exist."""
        return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of a user's tasks, newest ``created_at`` first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.user_id == user_id)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_recent(
        db: Session, limit: int = 20, offset: int = 0, user_id: Optional[str] = None
    ) -> List[MultiAgentTask]:
        """Return a page of recent tasks, pinned tasks first, then newest first.

        Args:
            user_id: When truthy, only this user's tasks are returned;
                otherwise tasks of all users are listed.
        """
        query = db.query(MultiAgentTask)
        if user_id:
            query = query.filter(MultiAgentTask.user_id == user_id)
        return (
            query
            .order_by(MultiAgentTask.is_pinned.desc(), MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def update_result(
        db: Session, task_id: str, result: list
    ) -> Optional[MultiAgentTask]:
        """Replace the task result; a falsy ``result`` is stored as ``[]``.

        Returns the updated task, or ``None`` if the task does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.result = result if result else []
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_task_outline(
        db: Session, task_id: str, task_outline: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the task outline. Returns the task or ``None`` if missing."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.task_outline = task_outline
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_assigned_agents(
        db: Session, task_id: str, assigned_agents: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the agent assignment (step name -> list of agents).

        Returns the task, or ``None`` if it does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.assigned_agents = assigned_agents
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_agent_scores(
        db: Session, task_id: str, agent_scores: dict
    ) -> Optional[MultiAgentTask]:
        """Merge new agent scores into the stored ones.

        Keys present in ``agent_scores`` overwrite existing keys, and other
        stored keys are kept. ``None`` is treated as "no new scores".
        Returns the task, or ``None`` if it does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            # Build a new dict so the ORM sees a changed value.
            task.agent_scores = {**(task.agent_scores or {}), **(agent_scores or {})}
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_status(
        db: Session, task_id: str, status: str
    ) -> Optional[MultiAgentTask]:
        """Set the task status. Returns the task or ``None`` if missing."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.status = status
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Add one to the execution count (a NULL count is treated as 0).

        Returns the task, or ``None`` if it does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.execution_count = (task.execution_count or 0) + 1
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_generation_id(
        db: Session, task_id: str, generation_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the generation ID. Returns the task or ``None`` if missing."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.generation_id = generation_id
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_execution_id(
        db: Session, task_id: str, execution_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the execution ID. Returns the task or ``None`` if missing."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.execution_id = execution_id
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_rehearsal_log(
        db: Session, task_id: str, rehearsal_log: list
    ) -> Optional[MultiAgentTask]:
        """Replace the rehearsal log; a falsy value is stored as ``[]``.

        Returns the task, or ``None`` if it does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.rehearsal_log = rehearsal_log if rehearsal_log else []
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_is_pinned(
        db: Session, task_id: str, is_pinned: bool
    ) -> Optional[MultiAgentTask]:
        """Set the pinned flag. Returns the task or ``None`` if missing."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.is_pinned = is_pinned
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def append_rehearsal_log(
        db: Session, task_id: str, log_entry: dict
    ) -> Optional[MultiAgentTask]:
        """Append one entry to the rehearsal log.

        A stored value that is not a list (including ``None``) is replaced by
        a new single-entry log. Returns the task, or ``None`` if missing.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            current_log = task.rehearsal_log
            # Build a NEW list instead of appending in place. The loaded list
            # is also the ORM's committed snapshot, so an in-place append
            # followed by reassigning the same object compares equal, and no
            # UPDATE is emitted (the appended entry is silently lost).
            if isinstance(current_log, list):
                task.rehearsal_log = [*current_log, log_entry]
            else:
                task.rehearsal_log = [log_entry]
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def update_branches(
        db: Session, task_id: str, branches
    ) -> Optional[MultiAgentTask]:
        """Update the task's branch data.

        Two input formats are supported:

        - list: legacy format, stored as-is (a falsy value becomes ``[]``).
        - dict: ``{"flow_branches": [...], "task_process_branches": {...}}``.
          The two keys are saved independently. If only one key is given and
          the stored value is a dict, the other key is carried over from it.

        The caller's ``branches`` object is never modified.
        Returns the task, or ``None`` if it does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            if isinstance(branches, dict):
                # Shallow copy: the merge below must not inject keys into the
                # caller's dict, and the ORM should receive a fresh object.
                merged = dict(branches)
                # Deep copy the stored value so no nested object is shared
                # between the old and new column values.
                existing = copy.deepcopy(task.branches) if task.branches else {}
                if isinstance(existing, dict):
                    # Only flow_branches given: keep stored task_process_branches.
                    if 'flow_branches' in merged and 'task_process_branches' not in merged:
                        merged['task_process_branches'] = existing.get('task_process_branches', {})
                    # Only task_process_branches given: keep stored flow_branches.
                    if 'task_process_branches' in merged and 'flow_branches' not in merged:
                        merged['flow_branches'] = existing.get('flow_branches', [])
                task.branches = merged
            else:
                # Legacy list format.
                task.branches = branches if branches else []
            MultiAgentTaskCRUD._save(db, task)
        return task

    @staticmethod
    def get_branches(db: Session, task_id: str) -> Optional[list]:
        """Return the task's branch data, or ``[]`` if the task or its data is missing.

        NOTE: despite the annotation, this may return a dict (the newer format
        written by ``update_branches``).
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            return task.branches or []
        return []

    @staticmethod
    def get_by_status(
        db: Session, status: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of tasks with the given status, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.status == status)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_generation_id(
        db: Session, generation_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given generation ID (unordered)."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.generation_id == generation_id)
            .all()
        )

    @staticmethod
    def get_by_execution_id(
        db: Session, execution_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given execution ID (unordered)."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.execution_id == execution_id)
            .all()
        )

    @staticmethod
    def delete(db: Session, task_id: str) -> bool:
        """Delete the task; return ``True`` if it existed, ``False`` otherwise."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            db.delete(task)
            db.commit()
            return True
        return False
class UserAgentCRUD:
    """CRUD operations for per-user agent configurations (``UserAgent``)."""

    @staticmethod
    def create(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Insert a new agent configuration keyed by a random UUID4 string.

        The session is committed and the new row refreshed before returning.
        """
        new_agent = UserAgent(
            id=str(uuid.uuid4()),
            user_id=user_id,
            agent_name=agent_name,
            agent_config=agent_config,
        )
        db.add(new_agent)
        db.commit()
        db.refresh(new_agent)
        return new_agent

    @staticmethod
    def _by_user_and_name(db: Session, user_id: str, agent_name: str):
        """Build the query selecting a user's agents that have ``agent_name``."""
        return db.query(UserAgent).filter(
            UserAgent.user_id == user_id,
            UserAgent.agent_name == agent_name,
        )

    @staticmethod
    def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
        """Return the configuration with primary key ``agent_id``, or ``None``."""
        matching = db.query(UserAgent).filter(UserAgent.id == agent_id)
        return matching.first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[UserAgent]:
        """Return up to ``limit`` of a user's configurations, newest first."""
        newest_first = (
            db.query(UserAgent)
            .filter(UserAgent.user_id == user_id)
            .order_by(UserAgent.created_at.desc())
        )
        return newest_first.limit(limit).all()

    @staticmethod
    def get_by_name(
        db: Session, user_id: str, agent_name: str
    ) -> List[UserAgent]:
        """Return every configuration of ``user_id`` named ``agent_name``."""
        return UserAgentCRUD._by_user_and_name(db, user_id, agent_name).all()

    @staticmethod
    def update_config(
        db: Session, agent_id: str, agent_config: dict
    ) -> Optional[UserAgent]:
        """Replace a configuration's ``agent_config``.

        Returns the updated row, or ``None`` when ``agent_id`` is unknown.
        """
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if not agent:
            return agent
        agent.agent_config = agent_config
        db.commit()
        db.refresh(agent)
        return agent

    @staticmethod
    def delete(db: Session, agent_id: str) -> bool:
        """Delete a configuration; report whether a row was found and removed."""
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if not agent:
            return False
        db.delete(agent)
        db.commit()
        return True

    @staticmethod
    def upsert(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Update or insert a configuration identified by (user_id, agent_name).

        If a matching row exists, the first one found has its config replaced.
        Otherwise a new row is created via ``create``.
        """
        existing = UserAgentCRUD._by_user_and_name(db, user_id, agent_name).first()
        if not existing:
            return UserAgentCRUD.create(db, user_id, agent_name, agent_config)
        existing.agent_config = agent_config
        db.commit()
        db.refresh(existing)
        return existing
class ExportRecordCRUD:
    """CRUD operations for export records (``ExportRecord``)."""

    @staticmethod
    def create(
        db: Session,
        task_id: str,
        user_id: str,
        export_type: str,
        file_name: str,
        file_path: str,
        file_url: str = "",
        file_size: int = 0,
    ) -> ExportRecord:
        """Persist a new export record and return it refreshed from the database."""
        new_record = ExportRecord(
            task_id=task_id,
            user_id=user_id,
            export_type=export_type,
            file_name=file_name,
            file_path=file_path,
            file_url=file_url,
            file_size=file_size,
        )
        db.add(new_record)
        db.commit()
        db.refresh(new_record)
        return new_record

    @staticmethod
    def _newest_first(db: Session, criterion, limit: int) -> List[ExportRecord]:
        """Return up to ``limit`` records matching ``criterion``, newest first."""
        ordered = (
            db.query(ExportRecord)
            .filter(criterion)
            .order_by(ExportRecord.created_at.desc())
        )
        return ordered.limit(limit).all()

    @staticmethod
    def get_by_id(db: Session, record_id: int) -> Optional[ExportRecord]:
        """Return the record with primary key ``record_id``, or ``None``."""
        matching = db.query(ExportRecord).filter(ExportRecord.id == record_id)
        return matching.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of a task, newest first."""
        return ExportRecordCRUD._newest_first(db, ExportRecord.task_id == task_id, limit)

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of a user, newest first."""
        return ExportRecordCRUD._newest_first(db, ExportRecord.user_id == user_id, limit)

    @staticmethod
    def delete(db: Session, record_id: int) -> bool:
        """Delete one export record; report whether it existed."""
        record = ExportRecordCRUD.get_by_id(db, record_id)
        if not record:
            return False
        db.delete(record)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Delete all export records of a task in a single commit.

        Returns ``True`` if at least one record was deleted, else ``False``.
        """
        doomed = db.query(ExportRecord).filter(ExportRecord.task_id == task_id).all()
        if not doomed:
            return False
        for each in doomed:
            db.delete(each)
        db.commit()
        return True
class PlanShareCRUD:
    """CRUD operations for shared task plans (``PlanShare``)."""

    @staticmethod
    def create(
        db: Session,
        share_token: str,
        task_id: str,
        task_data: dict,
        expires_at: Optional[datetime] = None,
        extraction_code: Optional[str] = None,
    ) -> PlanShare:
        """Persist a new share record and return it refreshed from the database."""
        new_share = PlanShare(
            share_token=share_token,
            extraction_code=extraction_code,
            task_id=task_id,
            task_data=task_data,
            expires_at=expires_at,
        )
        db.add(new_share)
        db.commit()
        db.refresh(new_share)
        return new_share

    @staticmethod
    def get_by_token(db: Session, share_token: str) -> Optional[PlanShare]:
        """Return the share with ``share_token``, or ``None`` if there is none."""
        matching = db.query(PlanShare).filter(PlanShare.share_token == share_token)
        return matching.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 10
    ) -> List[PlanShare]:
        """Return up to ``limit`` shares of a task, newest first."""
        newest_first = (
            db.query(PlanShare)
            .filter(PlanShare.task_id == task_id)
            .order_by(PlanShare.created_at.desc())
        )
        return newest_first.limit(limit).all()

    @staticmethod
    def increment_view_count(db: Session, share_token: str) -> Optional[PlanShare]:
        """Add one to a share's view count (a NULL count is treated as 0).

        Returns the updated share, or ``None`` when the token is unknown.
        """
        share = PlanShareCRUD.get_by_token(db, share_token)
        if not share:
            return share
        views_so_far = share.view_count or 0
        share.view_count = views_so_far + 1
        db.commit()
        db.refresh(share)
        return share

    @staticmethod
    def delete(db: Session, share_token: str) -> bool:
        """Delete the share with ``share_token``; report whether it existed."""
        share = PlanShareCRUD.get_by_token(db, share_token)
        if not share:
            return False
        db.delete(share)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Delete all shares of a task in a single commit.

        Returns ``True`` if at least one share was deleted, else ``False``.
        """
        doomed = db.query(PlanShare).filter(PlanShare.task_id == task_id).all()
        if not doomed:
            return False
        for each in doomed:
            db.delete(each)
        db.commit()
        return True