"""
数据库 CRUD 操作

封装所有数据库操作方法 (基于 DATABASE_DESIGN.md)
"""

import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional, Union

from sqlalchemy.orm import Session

from .models import MultiAgentTask, UserAgent, ExportRecord, PlanShare


class MultiAgentTaskCRUD:
    """CRUD operations for ``MultiAgentTask`` records.

    Every write method commits the session immediately and refreshes the
    instance before returning it. ``update_*`` / ``append_*`` / ``increment_*``
    methods return ``None`` when no task with the given ``task_id`` exists.
    """

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _commit_refresh(db: Session, task: MultiAgentTask) -> MultiAgentTask:
        """Commit the session, reload ``task`` from the database and return it."""
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def _set_field(
        db: Session, task_id: str, field: str, value
    ) -> Optional[MultiAgentTask]:
        """Assign ``value`` to attribute ``field`` of the task and persist it.

        Returns the refreshed task, or ``None`` if ``task_id`` is unknown
        (in which case nothing is committed).
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            setattr(task, field, value)
            MultiAgentTaskCRUD._commit_refresh(db, task)
        return task

    # ------------------------------------------------------------------
    # Create / read
    # ------------------------------------------------------------------

    @staticmethod
    def create(
        db: Session,
        task_id: Optional[str] = None,  # optional; a UUID4 is generated when falsy
        user_id: str = "",
        query: str = "",
        agents_info: Optional[list] = None,
        task_outline: Optional[dict] = None,
        assigned_agents: Optional[list] = None,
        agent_scores: Optional[dict] = None,
        result: Optional[str] = None,
    ) -> MultiAgentTask:
        """Create and persist a new task record.

        Args:
            db: Active database session.
            task_id: Explicit task ID; when ``None`` or empty a random UUID4
                string is generated.
            user_id: ID of the owning user.
            query: Query text stored on the task.
            agents_info: Agent information list; defaults to a fresh empty list.
            task_outline: Optional task outline.
            assigned_agents: Optional assigned-agents data.
            agent_scores: Optional agent score data.
            result: Optional initial result.

        Returns:
            The committed and refreshed ``MultiAgentTask``.
        """
        task = MultiAgentTask(
            task_id=task_id or str(uuid.uuid4()),
            user_id=user_id,
            query=query,
            # A fresh list per call: a literal ``[]`` default would be one list
            # object shared by every task created without agents_info.
            agents_info=[] if agents_info is None else agents_info,
            task_outline=task_outline,
            assigned_agents=assigned_agents,
            agent_scores=agent_scores,
            result=result,
        )
        db.add(task)
        return MultiAgentTaskCRUD._commit_refresh(db, task)

    @staticmethod
    def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Return the task with ``task_id``, or ``None`` if it does not exist."""
        return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of a user's tasks, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.user_id == user_id)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_recent(
        db: Session, limit: int = 20, offset: int = 0, user_id: Optional[str] = None
    ) -> List[MultiAgentTask]:
        """Return a page of recent tasks, pinned tasks first, then newest first.

        When ``user_id`` is truthy only that user's tasks are returned;
        otherwise tasks of all users are considered.
        """
        query = db.query(MultiAgentTask)
        if user_id:
            query = query.filter(MultiAgentTask.user_id == user_id)
        return (
            query
            .order_by(MultiAgentTask.is_pinned.desc(), MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    # ------------------------------------------------------------------
    # Field updates
    # ------------------------------------------------------------------

    @staticmethod
    def update_result(
        db: Session, task_id: str, result: list
    ) -> Optional[MultiAgentTask]:
        """Replace the task result; a falsy ``result`` is stored as ``[]``."""
        return MultiAgentTaskCRUD._set_field(
            db, task_id, "result", result if result else []
        )

    @staticmethod
    def update_task_outline(
        db: Session, task_id: str, task_outline: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the task outline."""
        return MultiAgentTaskCRUD._set_field(db, task_id, "task_outline", task_outline)

    @staticmethod
    def update_assigned_agents(
        db: Session, task_id: str, assigned_agents: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the assigned agents (mapping of step name -> agent list)."""
        return MultiAgentTaskCRUD._set_field(
            db, task_id, "assigned_agents", assigned_agents
        )

    @staticmethod
    def update_agent_scores(
        db: Session, task_id: str, agent_scores: dict
    ) -> Optional[MultiAgentTask]:
        """Merge ``agent_scores`` into the task's existing scores.

        Keys in ``agent_scores`` overwrite existing keys; other existing keys
        are kept. ``None`` is treated as an empty mapping.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            # Assign a brand-new dict so the ORM sees a changed value.
            task.agent_scores = {**(task.agent_scores or {}), **(agent_scores or {})}
            MultiAgentTaskCRUD._commit_refresh(db, task)
        return task

    @staticmethod
    def update_status(
        db: Session, task_id: str, status: str
    ) -> Optional[MultiAgentTask]:
        """Set the task status."""
        return MultiAgentTaskCRUD._set_field(db, task_id, "status", status)

    @staticmethod
    def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Increment the task's execution count by one (``NULL`` counts as 0)."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            task.execution_count = (task.execution_count or 0) + 1
            MultiAgentTaskCRUD._commit_refresh(db, task)
        return task

    @staticmethod
    def update_generation_id(
        db: Session, task_id: str, generation_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the task's generation ID."""
        return MultiAgentTaskCRUD._set_field(db, task_id, "generation_id", generation_id)

    @staticmethod
    def update_execution_id(
        db: Session, task_id: str, execution_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the task's execution ID."""
        return MultiAgentTaskCRUD._set_field(db, task_id, "execution_id", execution_id)

    @staticmethod
    def update_rehearsal_log(
        db: Session, task_id: str, rehearsal_log: list
    ) -> Optional[MultiAgentTask]:
        """Replace the rehearsal log; a falsy value is stored as ``[]``."""
        return MultiAgentTaskCRUD._set_field(
            db, task_id, "rehearsal_log", rehearsal_log if rehearsal_log else []
        )

    @staticmethod
    def update_is_pinned(
        db: Session, task_id: str, is_pinned: bool
    ) -> Optional[MultiAgentTask]:
        """Set whether the task is pinned."""
        return MultiAgentTaskCRUD._set_field(db, task_id, "is_pinned", is_pinned)

    @staticmethod
    def append_rehearsal_log(
        db: Session, task_id: str, log_entry: dict
    ) -> Optional[MultiAgentTask]:
        """Append ``log_entry`` to the task's rehearsal log.

        If the stored log is empty, ``None`` or not a list, it is replaced by
        ``[log_entry]``.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            current_log = task.rehearsal_log or []
            # Build a NEW list instead of appending in place. NOTE(review):
            # assumes rehearsal_log is a JSON column without Mutable tracking
            # (confirm in models). For such columns SQLAlchemy compares old vs.
            # new value at flush; an in-place append makes them the same object,
            # so the change is not detected and the entry would be lost.
            if isinstance(current_log, list):
                task.rehearsal_log = [*current_log, log_entry]
            else:
                task.rehearsal_log = [log_entry]
            MultiAgentTaskCRUD._commit_refresh(db, task)
        return task

    @staticmethod
    def update_branches(
        db: Session, task_id: str, branches
    ) -> Optional[MultiAgentTask]:
        """Update the task's branch data.

        Two formats are supported:
        - ``list``: legacy format, replaces the stored value (falsy -> ``[]``).
        - ``dict``: new format ``{flow_branches: [...], task_process_branches: {...}}``.
          The two keys are saved independently: when only one is supplied, the
          other is carried over from the stored dict (if the stored value is a
          dict).

        The caller's ``branches`` dict is not modified.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            if isinstance(branches, dict):
                # Merge into a copy: never mutate the caller's argument, and always
                # hand the ORM a new object so the change is detected.
                merged = dict(branches)
                # Deep copy so the new value shares no nested objects with the
                # currently stored one.
                existing = copy.deepcopy(task.branches) if task.branches else {}
                if isinstance(existing, dict):
                    # Only flow_branches supplied: keep stored task_process_branches.
                    if 'flow_branches' in merged and 'task_process_branches' not in merged:
                        merged['task_process_branches'] = existing.get('task_process_branches', {})
                    # Only task_process_branches supplied: keep stored flow_branches.
                    if 'task_process_branches' in merged and 'flow_branches' not in merged:
                        merged['flow_branches'] = existing.get('flow_branches', [])
                task.branches = merged
            else:
                # Legacy list format.
                task.branches = branches if branches else []
            MultiAgentTaskCRUD._commit_refresh(db, task)
        return task

    @staticmethod
    def get_branches(db: Session, task_id: str) -> Union[list, dict]:
        """Return the task's branch data (list or dict); ``[]`` if absent or unknown task."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            return task.branches or []
        return []

    # ------------------------------------------------------------------
    # Other queries / delete
    # ------------------------------------------------------------------

    @staticmethod
    def get_by_status(
        db: Session, status: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of tasks with the given status, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.status == status)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_generation_id(
        db: Session, generation_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given generation ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.generation_id == generation_id)
            .all()
        )

    @staticmethod
    def get_by_execution_id(
        db: Session, execution_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given execution ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.execution_id == execution_id)
            .all()
        )

    @staticmethod
    def delete(db: Session, task_id: str) -> bool:
        """Delete the task; return ``True`` if it existed, else ``False``."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if not task:
            return False
        db.delete(task)
        db.commit()
        return True


class UserAgentCRUD:
    """CRUD operations for per-user agent configurations (``UserAgent``)."""

    @staticmethod
    def _by_user_and_name(db: Session, user_id: str, agent_name: str):
        """Build the query selecting a user's agents with the given name."""
        return db.query(UserAgent).filter(
            UserAgent.user_id == user_id,
            UserAgent.agent_name == agent_name,
        )

    @staticmethod
    def create(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Insert a new agent configuration with a random UUID4 id and return it."""
        new_agent = UserAgent(
            id=str(uuid.uuid4()),
            user_id=user_id,
            agent_name=agent_name,
            agent_config=agent_config,
        )
        db.add(new_agent)
        db.commit()
        db.refresh(new_agent)
        return new_agent

    @staticmethod
    def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
        """Return the configuration with primary key ``agent_id``, or ``None``."""
        by_id = db.query(UserAgent).filter(UserAgent.id == agent_id)
        return by_id.first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[UserAgent]:
        """Return up to ``limit`` of a user's configurations, newest first."""
        owned = db.query(UserAgent).filter(UserAgent.user_id == user_id)
        newest_first = owned.order_by(UserAgent.created_at.desc())
        return newest_first.limit(limit).all()

    @staticmethod
    def get_by_name(
        db: Session, user_id: str, agent_name: str
    ) -> List[UserAgent]:
        """Return every configuration of ``user_id`` named ``agent_name``."""
        return UserAgentCRUD._by_user_and_name(db, user_id, agent_name).all()

    @staticmethod
    def update_config(
        db: Session, agent_id: str, agent_config: dict
    ) -> Optional[UserAgent]:
        """Overwrite an agent's configuration; ``None`` if the id is unknown."""
        target = UserAgentCRUD.get_by_id(db, agent_id)
        if not target:
            return target
        target.agent_config = agent_config
        db.commit()
        db.refresh(target)
        return target

    @staticmethod
    def delete(db: Session, agent_id: str) -> bool:
        """Remove an agent configuration; report whether it existed."""
        target = UserAgentCRUD.get_by_id(db, agent_id)
        if not target:
            return False
        db.delete(target)
        db.commit()
        return True

    @staticmethod
    def upsert(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Update or insert a configuration keyed by ``(user_id, agent_name)``.

        If a matching row exists, its config is replaced; otherwise a new row is
        created via :meth:`create`.
        """
        match = UserAgentCRUD._by_user_and_name(db, user_id, agent_name).first()
        if not match:
            return UserAgentCRUD.create(
                db,
                user_id=user_id,
                agent_name=agent_name,
                agent_config=agent_config,
            )
        match.agent_config = agent_config
        db.commit()
        db.refresh(match)
        return match


class ExportRecordCRUD:
    """CRUD operations for export records (``ExportRecord``)."""

    @staticmethod
    def create(
        db: Session,
        task_id: str,
        user_id: str,
        export_type: str,
        file_name: str,
        file_path: str,
        file_url: str = "",
        file_size: int = 0,
    ) -> ExportRecord:
        """Insert an export record and return it after commit/refresh."""
        fields = dict(
            task_id=task_id,
            user_id=user_id,
            export_type=export_type,
            file_name=file_name,
            file_path=file_path,
            file_url=file_url,
            file_size=file_size,
        )
        new_record = ExportRecord(**fields)
        db.add(new_record)
        db.commit()
        db.refresh(new_record)
        return new_record

    @staticmethod
    def get_by_id(db: Session, record_id: int) -> Optional[ExportRecord]:
        """Return the export record with primary key ``record_id``, or ``None``."""
        by_id = db.query(ExportRecord).filter(ExportRecord.id == record_id)
        return by_id.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of a task, newest first."""
        for_task = db.query(ExportRecord).filter(ExportRecord.task_id == task_id)
        return for_task.order_by(ExportRecord.created_at.desc()).limit(limit).all()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of a user, newest first."""
        for_user = db.query(ExportRecord).filter(ExportRecord.user_id == user_id)
        return for_user.order_by(ExportRecord.created_at.desc()).limit(limit).all()

    @staticmethod
    def delete(db: Session, record_id: int) -> bool:
        """Remove one export record; report whether it existed."""
        target = ExportRecordCRUD.get_by_id(db, record_id)
        if not target:
            return False
        db.delete(target)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Remove every export record of a task; ``False`` if there were none."""
        doomed = db.query(ExportRecord).filter(ExportRecord.task_id == task_id).all()
        if not doomed:
            return False
        for rec in doomed:
            db.delete(rec)
        db.commit()
        return True


class PlanShareCRUD:
    """CRUD operations for shared task plans (``PlanShare``)."""

    @staticmethod
    def create(
        db: Session,
        share_token: str,
        task_id: str,
        task_data: dict,
        expires_at: Optional[datetime] = None,
        extraction_code: Optional[str] = None,
    ) -> PlanShare:
        """Insert a share record and return it after commit/refresh."""
        fields = dict(
            share_token=share_token,
            extraction_code=extraction_code,
            task_id=task_id,
            task_data=task_data,
            expires_at=expires_at,
        )
        new_share = PlanShare(**fields)
        db.add(new_share)
        db.commit()
        db.refresh(new_share)
        return new_share

    @staticmethod
    def get_by_token(db: Session, share_token: str) -> Optional[PlanShare]:
        """Return the share identified by ``share_token``, or ``None``."""
        by_token = db.query(PlanShare).filter(PlanShare.share_token == share_token)
        return by_token.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 10
    ) -> List[PlanShare]:
        """Return up to ``limit`` shares of a task, newest first."""
        for_task = db.query(PlanShare).filter(PlanShare.task_id == task_id)
        return for_task.order_by(PlanShare.created_at.desc()).limit(limit).all()

    @staticmethod
    def increment_view_count(db: Session, share_token: str) -> Optional[PlanShare]:
        """Bump a share's view counter (``NULL`` counts as 0); ``None`` if unknown."""
        target = PlanShareCRUD.get_by_token(db, share_token)
        if not target:
            return target
        target.view_count = (target.view_count or 0) + 1
        db.commit()
        db.refresh(target)
        return target

    @staticmethod
    def delete(db: Session, share_token: str) -> bool:
        """Remove a share by token; report whether it existed."""
        target = PlanShareCRUD.get_by_token(db, share_token)
        if not target:
            return False
        db.delete(target)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Remove every share of a task; ``False`` if there were none."""
        doomed = db.query(PlanShare).filter(PlanShare.task_id == task_id).all()
        if not doomed:
            return False
        for item in doomed:
            db.delete(item)
        db.commit()
        return True