Files
AgentCoord/backend/db/models.py
2026-03-12 13:35:04 +08:00

171 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SQLAlchemy ORM 数据模型
对应数据库表结构 (基于 DATABASE_DESIGN.md)
"""
import uuid
from datetime import datetime, timezone
from enum import Enum as PyEnum
from sqlalchemy import Column, String, Text, DateTime, Integer, Enum, Index, ForeignKey, Boolean
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from .database import Base
class TaskStatus(str, PyEnum):
    """Lifecycle states of a multi-agent task.

    Subclasses ``str`` so each member compares equal to its lowercase value.
    """

    # Task process (plan) is still being generated.
    GENERATING = "generating"
    # Task is currently being executed.
    EXECUTING = "executing"
    # Execution was stopped manually by the user.
    STOPPED = "stopped"
    # Task finished normally.
    COMPLETED = "completed"


def utc_now():
    """Return the current moment as a timezone-aware ``datetime`` in UTC.

    Passed (uncalled) as the ``default``/``onupdate`` callable of the
    timestamp columns in this module.
    """
    current = datetime.now(tz=timezone.utc)
    return current


class MultiAgentTask(Base):
    """ORM model for one multi-agent task record (table ``multi_agent_tasks``).

    Holds the user's query, agent configuration, generated outline and
    assignments, results, and the generation/execution bookkeeping whose
    progress is tracked by ``status``.
    """

    __tablename__ = "multi_agent_tasks"

    # --- identity ---
    task_id = Column(String(64), primary_key=True)
    user_id = Column(String(64), nullable=False, index=True)

    # --- request and planning payloads ---
    query = Column(Text, nullable=False)
    agents_info = Column(JSONB, nullable=False)
    task_outline = Column(JSONB)
    assigned_agents = Column(JSONB)
    agent_scores = Column(JSONB)
    result = Column(JSONB)

    # Lifecycle state.
    # NOTE(review): ``create_type=False`` is a ``postgresql.ENUM`` option,
    # presumably meant to stop SQLAlchemy emitting CREATE TYPE for
    # ``task_status_enum``. Confirm the generic ``Enum`` honors it.
    # NOTE(review): with a Python enum class and no ``values_callable``,
    # SQLAlchemy persists member *names* ("GENERATING"), not values
    # ("generating"). Confirm this matches the existing DB enum type.
    status = Column(
        Enum(TaskStatus, name="task_status_enum", create_type=False),
        nullable=False,
        default=TaskStatus.GENERATING,
    )
    execution_count = Column(Integer, nullable=False, default=0)
    generation_id = Column(String(64))
    execution_id = Column(String(64))
    rehearsal_log = Column(JSONB)
    branches = Column(JSONB)  # exploration branches of the task outline
    is_pinned = Column(Boolean, nullable=False, default=False)  # pinned-to-top flag

    # --- timestamps (timezone-aware, set via utc_now) ---
    created_at = Column(DateTime(timezone=True), default=utc_now)
    updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)

    __table_args__ = (
        Index("idx_multi_agent_tasks_status", "status"),
        Index("idx_multi_agent_tasks_generation_id", "generation_id"),
        Index("idx_multi_agent_tasks_execution_id", "execution_id"),
    )

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict with one entry per column.

        ``status`` is reduced to its enum ``.value`` and the timestamps are
        rendered with ``isoformat()``. Each of those becomes ``None`` when unset.
        All other columns are returned as stored.
        """
        status = self.status
        created, updated = self.created_at, self.updated_at
        return dict(
            task_id=self.task_id,
            user_id=self.user_id,
            query=self.query,
            agents_info=self.agents_info,
            task_outline=self.task_outline,
            assigned_agents=self.assigned_agents,
            agent_scores=self.agent_scores,
            result=self.result,
            status=status.value if status else None,
            execution_count=self.execution_count,
            generation_id=self.generation_id,
            execution_id=self.execution_id,
            rehearsal_log=self.rehearsal_log,
            branches=self.branches,
            is_pinned=self.is_pinned,
            created_at=created.isoformat() if created else None,
            updated_at=updated.isoformat() if updated else None,
        )


class ExportRecord(Base):
    """ORM model for one exported file produced from a task (table ``export_records``)."""

    __tablename__ = "export_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): ``idx_export_records_task_user`` leads with task_id, so it
    # can already serve task_id-only lookups. The single-column index declared
    # here may be redundant.
    task_id = Column(String(64), nullable=False, index=True)  # owning task ID
    user_id = Column(String(64), nullable=False, index=True)  # owning user ID
    # Export format: doc / markdown / mindmap / infographic / excel / ppt
    export_type = Column(String(32), nullable=False)
    file_name = Column(String(256), nullable=False)
    file_path = Column(String(512), nullable=False)  # storage path on the server
    file_url = Column(String(512))  # URL the file can be accessed from
    file_size = Column(Integer, default=0)  # file size in bytes
    created_at = Column(DateTime(timezone=True), default=utc_now)

    __table_args__ = (
        Index("idx_export_records_task_user", "task_id", "user_id"),
    )

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of this record.

        ``created_at`` is rendered with ``isoformat()``, or ``None`` when unset.
        """
        created = self.created_at
        return dict(
            id=self.id,
            task_id=self.task_id,
            user_id=self.user_id,
            export_type=self.export_type,
            file_name=self.file_name,
            file_path=self.file_path,
            file_url=self.file_url,
            file_size=self.file_size,
            created_at=created.isoformat() if created else None,
        )


class UserAgent(Base):
    """ORM model for an agent configuration saved by a user (optional table ``user_agents``)."""

    __tablename__ = "user_agents"

    # No default is declared, so callers must supply the primary key.
    id = Column(String(64), primary_key=True)
    # NOTE(review): ``idx_user_agents_user_created`` leads with user_id, so the
    # single-column index on user_id may be redundant.
    user_id = Column(String(64), nullable=False, index=True)
    agent_name = Column(String(100), nullable=False)
    agent_config = Column(JSONB, nullable=False)
    created_at = Column(DateTime(timezone=True), default=utc_now)

    __table_args__ = (
        Index("idx_user_agents_user_created", "user_id", "created_at"),
    )

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of this saved agent.

        ``created_at`` is rendered with ``isoformat()``, or ``None`` when unset.
        """
        created = self.created_at
        return dict(
            id=self.id,
            user_id=self.user_id,
            agent_name=self.agent_name,
            agent_config=self.agent_config,
            created_at=created.isoformat() if created else None,
        )


class PlanShare(Base):
    """ORM model for a shared task plan (table ``plan_shares``).

    Each row stores a snapshot of a task's data that is looked up by its
    share token. The snapshot may optionally be protected by an extraction
    code and an expiry time.
    """

    __tablename__ = "plan_shares"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique share token. ``unique=True`` together with ``index=True`` already
    # gives this column a unique index, which also serves token lookups. A
    # separate ``Index("idx_plan_shares_token", "share_token")`` would only
    # duplicate it, adding write and storage overhead, so none is declared.
    # NOTE(review): databases created from the previous model may still have
    # idx_plan_shares_token. Drop it in a migration if desired.
    share_token = Column(String(64), unique=True, index=True, nullable=False)
    # Optional extraction (access) code. The intended format is 4 alphanumeric
    # characters; the column allows up to 8.
    extraction_code = Column(String(8), nullable=True)
    task_id = Column(String(64), nullable=False, index=True)  # ID of the shared task
    # Full task data snapshot. The original comment describes it as redacted
    # (sensitive fields removed).
    task_data = Column(JSONB, nullable=False)
    created_at = Column(DateTime(timezone=True), default=utc_now)
    expires_at = Column(DateTime(timezone=True), nullable=True)  # expiry time (nullable)
    view_count = Column(Integer, default=0)  # number of times the share was viewed

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of this share record.

        ``created_at`` and ``expires_at`` are rendered with ``isoformat()``,
        or ``None`` when unset.
        """
        return {
            "id": self.id,
            "share_token": self.share_token,
            "extraction_code": self.extraction_code,
            "task_id": self.task_id,
            "task_data": self.task_data,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "view_count": self.view_count,
        }