"""
Database CRUD operations.

Wraps all database access methods (based on DATABASE_DESIGN.md).
"""

import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional, Union

from sqlalchemy.orm import Session

from .models import MultiAgentTask, UserAgent, ExportRecord, PlanShare


class MultiAgentTaskCRUD:
    """CRUD operations for multi-agent tasks."""

    @staticmethod
    def create(
        db: Session,
        task_id: Optional[str] = None,
        user_id: str = "",
        query: str = "",
        agents_info: Optional[list] = None,
        task_outline: Optional[dict] = None,
        assigned_agents: Optional[list] = None,
        agent_scores: Optional[dict] = None,
        result: Optional[str] = None,
    ) -> MultiAgentTask:
        """Create, commit and return a new task record.

        Args:
            db: Active session; the new row is committed immediately.
            task_id: Task ID. When omitted (or empty) a random UUID4 string
                is generated.
            user_id: Owner of the task.
            query: The user's query text.
            agents_info: Agent information list; defaults to a new empty list.
            task_outline: Optional task outline.
            assigned_agents: Optional agent assignment.
            agent_scores: Optional agent scores.
            result: Optional task result.

        Returns:
            The persisted and refreshed ``MultiAgentTask``.
        """
        task = MultiAgentTask(
            task_id=task_id or str(uuid.uuid4()),
            user_id=user_id,
            query=query,
            # A fresh list per call: a literal ``[]`` default would be one object
            # shared by every task created without ``agents_info``.
            agents_info=[] if agents_info is None else agents_info,
            task_outline=task_outline,
            assigned_agents=assigned_agents,
            agent_scores=agent_scores,
            result=result,
        )
        db.add(task)
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Return the task with ``task_id``, or ``None`` if it does not exist."""
        return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()

    @staticmethod
    def _update_fields(db: Session, task_id: str, **fields: object) -> Optional[MultiAgentTask]:
        """Set ``fields`` on the task with ``task_id``, then commit and refresh it.

        Returns:
            The refreshed task, or ``None`` when no task matches ``task_id``
            (nothing is committed in that case).
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        for name, value in fields.items():
            setattr(task, name, value)
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of ``user_id``'s tasks, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.user_id == user_id)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_recent(
        db: Session, limit: int = 20, offset: int = 0, user_id: Optional[str] = None
    ) -> List[MultiAgentTask]:
        """Return a page of recent tasks: pinned tasks first, then newest first.

        Args:
            user_id: When truthy, only this user's tasks are returned.
        """
        query = db.query(MultiAgentTask)
        if user_id:
            query = query.filter(MultiAgentTask.user_id == user_id)
        return (
            query
            .order_by(MultiAgentTask.is_pinned.desc(), MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def update_result(
        db: Session, task_id: str, result: list
    ) -> Optional[MultiAgentTask]:
        """Replace the task result; a falsy ``result`` is stored as ``[]``."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, result=result if result else []
        )

    @staticmethod
    def update_task_outline(
        db: Session, task_id: str, task_outline: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the task outline."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, task_outline=task_outline)

    @staticmethod
    def update_assigned_agents(
        db: Session, task_id: str, assigned_agents: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the agent assignment (step name -> list of agents)."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, assigned_agents=assigned_agents
        )

    @staticmethod
    def update_agent_scores(
        db: Session, task_id: str, agent_scores: dict
    ) -> Optional[MultiAgentTask]:
        """Merge ``agent_scores`` into the stored scores.

        Keys already present are overwritten by the new values; new keys
        (e.g. scores for newly added steps) are appended. A ``None`` or empty
        ``agent_scores`` leaves the stored scores unchanged.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        # Building a new dict (rather than updating in place) makes the change
        # visible to SQLAlchemy's change detection on the JSON column.
        task.agent_scores = {**(task.agent_scores or {}), **(agent_scores or {})}
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_status(
        db: Session, task_id: str, status: str
    ) -> Optional[MultiAgentTask]:
        """Set the task status."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, status=status)

    @staticmethod
    def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Increment the task's execution count by one (``NULL`` counts as 0)."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        # NOTE(review): read-modify-write in Python; two concurrent calls can
        # read the same value and lose an increment.
        task.execution_count = (task.execution_count or 0) + 1
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_generation_id(
        db: Session, task_id: str, generation_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the generation ID."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, generation_id=generation_id)

    @staticmethod
    def update_execution_id(
        db: Session, task_id: str, execution_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the execution ID."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, execution_id=execution_id)

    @staticmethod
    def update_rehearsal_log(
        db: Session, task_id: str, rehearsal_log: list
    ) -> Optional[MultiAgentTask]:
        """Replace the rehearsal log; a falsy value is stored as ``[]``."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, rehearsal_log=rehearsal_log if rehearsal_log else []
        )

    @staticmethod
    def update_is_pinned(
        db: Session, task_id: str, is_pinned: bool
    ) -> Optional[MultiAgentTask]:
        """Set the task's pinned flag."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, is_pinned=is_pinned)

    @staticmethod
    def append_rehearsal_log(
        db: Session, task_id: str, log_entry: dict
    ) -> Optional[MultiAgentTask]:
        """Append ``log_entry`` to the rehearsal log.

        If the stored log is missing or not a list, it is replaced by
        ``[log_entry]``.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        current_log = task.rehearsal_log
        # Assign a NEW list. Appending in place and re-assigning the same object
        # is not detected as a change on a plain JSON column, so no UPDATE is
        # emitted and the refresh below would discard the new entry.
        if isinstance(current_log, list):
            task.rehearsal_log = [*current_log, log_entry]
        else:
            task.rehearsal_log = [log_entry]
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_branches(
        db: Session, task_id: str, branches: Union[list, dict, None]
    ) -> Optional[MultiAgentTask]:
        """Update the task's branch data.

        Two formats are supported:

        - ``list``: legacy format, stored as-is (a falsy value becomes ``[]``).
        - ``dict``: new format ``{flow_branches: [...], task_process_branches: {...}}``.
          The two keys are saved independently: if only one is supplied, the
          other is carried over from the stored value (when that is a dict).

        The caller's ``branches`` object is not modified.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        if isinstance(branches, dict):
            # Work on a shallow copy so the caller's dict is left untouched.
            new_branches = dict(branches)
            # Deep-copy the stored value so the new value shares no references with it.
            existing = copy.deepcopy(task.branches) if task.branches else {}
            if isinstance(existing, dict):
                # Only flow_branches supplied: keep the stored task_process_branches.
                if 'flow_branches' in new_branches and 'task_process_branches' not in new_branches:
                    new_branches['task_process_branches'] = existing.get('task_process_branches', {})
                # Only task_process_branches supplied: keep the stored flow_branches.
                if 'task_process_branches' in new_branches and 'flow_branches' not in new_branches:
                    new_branches['flow_branches'] = existing.get('flow_branches', [])
            task.branches = new_branches
        else:
            task.branches = branches if branches else []
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def get_branches(db: Session, task_id: str) -> Optional[Union[list, dict]]:
        """Return the task's branch data.

        This may be a list (legacy format) or a dict (new format). Returns
        ``[]`` when the task does not exist or has no branch data.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task:
            return task.branches or []
        return []

    @staticmethod
    def get_by_status(
        db: Session, status: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of tasks with ``status``, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.status == status)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_generation_id(
        db: Session, generation_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given generation ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.generation_id == generation_id)
            .all()
        )

    @staticmethod
    def get_by_execution_id(
        db: Session, execution_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given execution ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.execution_id == execution_id)
            .all()
        )

    @staticmethod
    def delete(db: Session, task_id: str) -> bool:
        """Delete the task; return ``True`` if it existed, else ``False``."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return False
        db.delete(task)
        db.commit()
        return True


class UserAgentCRUD:
    """CRUD operations for per-user agent configurations."""

    @staticmethod
    def create(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Create, commit and return a new agent configuration with a UUID4 id."""
        agent = UserAgent(
            id=str(uuid.uuid4()),
            user_id=user_id,
            agent_name=agent_name,
            agent_config=agent_config,
        )
        db.add(agent)
        db.commit()
        db.refresh(agent)
        return agent

    @staticmethod
    def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
        """Return the configuration with ``agent_id``, or ``None``."""
        return db.query(UserAgent).filter(UserAgent.id == agent_id).first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[UserAgent]:
        """Return up to ``limit`` of ``user_id``'s agent configurations, newest first."""
        return (
            db.query(UserAgent)
            .filter(UserAgent.user_id == user_id)
            .order_by(UserAgent.created_at.desc())
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_name(
        db: Session, user_id: str, agent_name: str
    ) -> List[UserAgent]:
        """Return all of ``user_id``'s configurations named ``agent_name``."""
        return (
            db.query(UserAgent)
            .filter(
                UserAgent.user_id == user_id,
                UserAgent.agent_name == agent_name,
            )
            .all()
        )

    @staticmethod
    def update_config(
        db: Session, agent_id: str, agent_config: dict
    ) -> Optional[UserAgent]:
        """Replace the configuration dict; return ``None`` if ``agent_id`` is unknown."""
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if agent is None:
            return None
        agent.agent_config = agent_config
        db.commit()
        db.refresh(agent)
        return agent

    @staticmethod
    def delete(db: Session, agent_id: str) -> bool:
        """Delete the configuration; return ``True`` if it existed, else ``False``."""
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if agent is None:
            return False
        db.delete(agent)
        db.commit()
        return True

    @staticmethod
    def upsert(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Update or insert an agent configuration keyed by (user_id, agent_name).

        If a record with the same ``user_id`` and ``agent_name`` exists, its
        configuration is replaced. Otherwise a new record is created.
        """
        existing = (
            db.query(UserAgent)
            .filter(
                UserAgent.user_id == user_id,
                UserAgent.agent_name == agent_name,
            )
            .first()
        )
        if existing is None:
            return UserAgentCRUD.create(db, user_id, agent_name, agent_config)
        existing.agent_config = agent_config
        db.commit()
        db.refresh(existing)
        return existing


class ExportRecordCRUD:
    """CRUD operations for export records."""

    @staticmethod
    def create(
        db: Session,
        task_id: str,
        user_id: str,
        export_type: str,
        file_name: str,
        file_path: str,
        file_url: str = "",
        file_size: int = 0,
    ) -> ExportRecord:
        """Create, commit and return a new export record."""
        record = ExportRecord(
            task_id=task_id,
            user_id=user_id,
            export_type=export_type,
            file_name=file_name,
            file_path=file_path,
            file_url=file_url,
            file_size=file_size,
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        return record

    @staticmethod
    def get_by_id(db: Session, record_id: int) -> Optional[ExportRecord]:
        """Return the export record with ``record_id``, or ``None``."""
        return db.query(ExportRecord).filter(ExportRecord.id == record_id).first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of ``task_id``, newest first."""
        return (
            db.query(ExportRecord)
            .filter(ExportRecord.task_id == task_id)
            .order_by(ExportRecord.created_at.desc())
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """Return up to ``limit`` export records of ``user_id``, newest first."""
        return (
            db.query(ExportRecord)
            .filter(ExportRecord.user_id == user_id)
            .order_by(ExportRecord.created_at.desc())
            .limit(limit)
            .all()
        )

    @staticmethod
    def delete(db: Session, record_id: int) -> bool:
        """Delete the export record; return ``True`` if it existed, else ``False``."""
        record = ExportRecordCRUD.get_by_id(db, record_id)
        if record is None:
            return False
        db.delete(record)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Delete all export records of ``task_id``; return ``True`` if any existed."""
        records = db.query(ExportRecord).filter(ExportRecord.task_id == task_id).all()
        if not records:
            return False
        # Delete through the session (one object at a time) so ORM-level
        # delete handling applies to each record.
        for record in records:
            db.delete(record)
        db.commit()
        return True


class PlanShareCRUD:
    """CRUD operations for task (plan) shares."""

    @staticmethod
    def create(
        db: Session,
        share_token: str,
        task_id: str,
        task_data: dict,
        expires_at: Optional[datetime] = None,
        extraction_code: Optional[str] = None,
    ) -> PlanShare:
        """Create, commit and return a new share record."""
        share = PlanShare(
            share_token=share_token,
            extraction_code=extraction_code,
            task_id=task_id,
            task_data=task_data,
            expires_at=expires_at,
        )
        db.add(share)
        db.commit()
        db.refresh(share)
        return share

    @staticmethod
    def get_by_token(db: Session, share_token: str) -> Optional[PlanShare]:
        """Return the share record with ``share_token``, or ``None``."""
        return db.query(PlanShare).filter(PlanShare.share_token == share_token).first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 10
    ) -> List[PlanShare]:
        """Return up to ``limit`` share records of ``task_id``, newest first."""
        return (
            db.query(PlanShare)
            .filter(PlanShare.task_id == task_id)
            .order_by(PlanShare.created_at.desc())
            .limit(limit)
            .all()
        )

    @staticmethod
    def increment_view_count(db: Session, share_token: str) -> Optional[PlanShare]:
        """Increment the share's view count by one (``NULL`` counts as 0)."""
        share = PlanShareCRUD.get_by_token(db, share_token)
        if share is None:
            return None
        # NOTE(review): read-modify-write in Python; concurrent views may
        # overwrite each other's increment.
        share.view_count = (share.view_count or 0) + 1
        db.commit()
        db.refresh(share)
        return share

    @staticmethod
    def delete(db: Session, share_token: str) -> bool:
        """Delete the share record; return ``True`` if it existed, else ``False``."""
        share = PlanShareCRUD.get_by_token(db, share_token)
        if share is None:
            return False
        db.delete(share)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Delete all share records of ``task_id``; return ``True`` if any existed."""
        shares = db.query(PlanShare).filter(PlanShare.task_id == task_id).all()
        if not shares:
            return False
        for share in shares:
            db.delete(share)
        db.commit()
        return True