feat:1.数据库存储功能添加(初版)2.后端REST API版本代码清理

This commit is contained in:
liailing1026
2026-02-25 10:55:51 +08:00
parent f736cd104a
commit 2140cfaf92
35 changed files with 3912 additions and 2981 deletions

25
backend/db/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
"""
AgentCoord 数据库模块
提供 PostgreSQL 数据库连接、模型和 CRUD 操作
基于 DATABASE_DESIGN.md 设计
"""
from .database import get_db, get_db_context, test_connection, engine, text
from .models import MultiAgentTask, UserAgent, TaskStatus
from .crud import MultiAgentTaskCRUD, UserAgentCRUD
__all__ = [
# 连接管理
"get_db",
"get_db_context",
"test_connection",
"engine",
"text",
# 模型
"MultiAgentTask",
"UserAgent",
"TaskStatus",
# CRUD
"MultiAgentTaskCRUD",
"UserAgentCRUD",
]

404
backend/db/crud.py Normal file
View File

@@ -0,0 +1,404 @@
"""
数据库 CRUD 操作
封装所有数据库操作方法 (基于 DATABASE_DESIGN.md)
"""
import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from .models import MultiAgentTask, UserAgent
class MultiAgentTaskCRUD:
"""多智能体任务 CRUD 操作"""
@staticmethod
def create(
db: Session,
task_id: Optional[str] = None, # 可选,如果为 None 则自动生成
user_id: str = "",
query: str = "",
agents_info: list = [],
task_outline: Optional[dict] = None,
assigned_agents: Optional[list] = None,
agent_scores: Optional[dict] = None,
result: Optional[str] = None,
) -> MultiAgentTask:
"""创建任务记录"""
task = MultiAgentTask(
task_id=task_id or str(uuid.uuid4()), # 如果没传则生成新的
user_id=user_id,
query=query,
agents_info=agents_info,
task_outline=task_outline,
assigned_agents=assigned_agents,
agent_scores=agent_scores,
result=result,
)
db.add(task)
db.commit()
db.refresh(task)
return task
@staticmethod
def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
"""根据任务 ID 获取记录"""
return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
@staticmethod
def get_by_user_id(
db: Session, user_id: str, limit: int = 50, offset: int = 0
) -> List[MultiAgentTask]:
"""根据用户 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.user_id == user_id)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def get_recent(
db: Session, limit: int = 20, offset: int = 0
) -> List[MultiAgentTask]:
"""获取最近的任务记录"""
return (
db.query(MultiAgentTask)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def update_result(
db: Session, task_id: str, result: list
) -> Optional[MultiAgentTask]:
"""更新任务结果"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.result = result if result else []
db.commit()
db.refresh(task)
return task
@staticmethod
def update_task_outline(
db: Session, task_id: str, task_outline: dict
) -> Optional[MultiAgentTask]:
"""更新任务大纲"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.task_outline = task_outline
db.commit()
db.refresh(task)
return task
@staticmethod
def update_assigned_agents(
db: Session, task_id: str, assigned_agents: dict
) -> Optional[MultiAgentTask]:
"""更新分配的智能体(步骤名 -> agent列表"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.assigned_agents = assigned_agents
db.commit()
db.refresh(task)
return task
@staticmethod
def update_agent_scores(
db: Session, task_id: str, agent_scores: dict
) -> Optional[MultiAgentTask]:
"""更新智能体评分(合并模式,追加新步骤的评分)"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
# 合并现有评分数据和新评分数据
existing_scores = task.agent_scores or {}
merged_scores = {**existing_scores, **agent_scores} # 新数据覆盖/追加旧数据
task.agent_scores = merged_scores
db.commit()
db.refresh(task)
return task
@staticmethod
def update_status(
db: Session, task_id: str, status: str
) -> Optional[MultiAgentTask]:
"""更新任务状态"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.status = status
db.commit()
db.refresh(task)
return task
@staticmethod
def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
"""增加任务执行次数"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.execution_count = (task.execution_count or 0) + 1
db.commit()
db.refresh(task)
return task
@staticmethod
def update_generation_id(
db: Session, task_id: str, generation_id: str
) -> Optional[MultiAgentTask]:
"""更新生成 ID"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.generation_id = generation_id
db.commit()
db.refresh(task)
return task
@staticmethod
def update_execution_id(
db: Session, task_id: str, execution_id: str
) -> Optional[MultiAgentTask]:
"""更新执行 ID"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.execution_id = execution_id
db.commit()
db.refresh(task)
return task
@staticmethod
def update_rehearsal_log(
db: Session, task_id: str, rehearsal_log: list
) -> Optional[MultiAgentTask]:
"""更新排练日志"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
task.rehearsal_log = rehearsal_log if rehearsal_log else []
db.commit()
db.refresh(task)
return task
@staticmethod
def append_rehearsal_log(
db: Session, task_id: str, log_entry: dict
) -> Optional[MultiAgentTask]:
"""追加排练日志条目"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
current_log = task.rehearsal_log or []
if isinstance(current_log, list):
current_log.append(log_entry)
else:
current_log = [log_entry]
task.rehearsal_log = current_log
db.commit()
db.refresh(task)
return task
@staticmethod
def update_branches(
db: Session, task_id: str, branches
) -> Optional[MultiAgentTask]:
"""更新任务分支数据
支持两种格式:
- list: 旧格式,直接覆盖
- dict: 新格式 { flow_branches: [...], task_process_branches: {...} }
两个 key 独立保存,互不干扰。
"""
import copy
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
if isinstance(branches, dict):
# 新格式:字典,独立保存两个 key互不干扰
# 使用深拷贝避免引用共享问题
existing = copy.deepcopy(task.branches) if task.branches else {}
if isinstance(existing, dict):
# 如果只更新 flow_branches保留已有的 task_process_branches
if 'flow_branches' in branches and 'task_process_branches' not in branches:
branches['task_process_branches'] = existing.get('task_process_branches', {})
# 如果只更新 task_process_branches保留已有的 flow_branches
if 'task_process_branches' in branches and 'flow_branches' not in branches:
branches['flow_branches'] = existing.get('flow_branches', [])
task.branches = branches
else:
# 旧格式:列表
task.branches = branches if branches else []
db.commit()
db.refresh(task)
return task
@staticmethod
def get_branches(db: Session, task_id: str) -> Optional[list]:
"""获取任务分支数据"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
return task.branches or []
return []
@staticmethod
def get_by_status(
db: Session, status: str, limit: int = 50, offset: int = 0
) -> List[MultiAgentTask]:
"""根据状态获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.status == status)
.order_by(MultiAgentTask.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
@staticmethod
def get_by_generation_id(
db: Session, generation_id: str
) -> List[MultiAgentTask]:
"""根据生成 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.generation_id == generation_id)
.all()
)
@staticmethod
def get_by_execution_id(
db: Session, execution_id: str
) -> List[MultiAgentTask]:
"""根据执行 ID 获取任务记录"""
return (
db.query(MultiAgentTask)
.filter(MultiAgentTask.execution_id == execution_id)
.all()
)
@staticmethod
def delete(db: Session, task_id: str) -> bool:
"""删除任务记录"""
task = db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()
if task:
db.delete(task)
db.commit()
return True
return False
class UserAgentCRUD:
"""用户智能体配置 CRUD 操作"""
@staticmethod
def create(
db: Session,
user_id: str,
agent_name: str,
agent_config: dict,
) -> UserAgent:
"""创建用户智能体配置"""
agent = UserAgent(
id=str(uuid.uuid4()),
user_id=user_id,
agent_name=agent_name,
agent_config=agent_config,
)
db.add(agent)
db.commit()
db.refresh(agent)
return agent
@staticmethod
def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
"""根据 ID 获取配置"""
return db.query(UserAgent).filter(UserAgent.id == agent_id).first()
@staticmethod
def get_by_user_id(
db: Session, user_id: str, limit: int = 50
) -> List[UserAgent]:
"""根据用户 ID 获取所有智能体配置"""
return (
db.query(UserAgent)
.filter(UserAgent.user_id == user_id)
.order_by(UserAgent.created_at.desc())
.limit(limit)
.all()
)
@staticmethod
def get_by_name(
db: Session, user_id: str, agent_name: str
) -> List[UserAgent]:
"""根据用户 ID 和智能体名称获取配置"""
return (
db.query(UserAgent)
.filter(
UserAgent.user_id == user_id,
UserAgent.agent_name == agent_name,
)
.all()
)
@staticmethod
def update_config(
db: Session, agent_id: str, agent_config: dict
) -> Optional[UserAgent]:
"""更新智能体配置"""
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
if agent:
agent.agent_config = agent_config
db.commit()
db.refresh(agent)
return agent
@staticmethod
def delete(db: Session, agent_id: str) -> bool:
"""删除智能体配置"""
agent = db.query(UserAgent).filter(UserAgent.id == agent_id).first()
if agent:
db.delete(agent)
db.commit()
return True
return False
@staticmethod
def upsert(
db: Session,
user_id: str,
agent_name: str,
agent_config: dict,
) -> UserAgent:
"""更新或插入用户智能体配置(根据 user_id + agent_name 判断唯一性)
如果已存在相同 user_id 和 agent_name 的记录,则更新配置;
否则创建新记录。
"""
existing = (
db.query(UserAgent)
.filter(
UserAgent.user_id == user_id,
UserAgent.agent_name == agent_name,
)
.first()
)
if existing:
# 更新现有记录
existing.agent_config = agent_config
db.commit()
db.refresh(existing)
return existing
else:
# 创建新记录
agent = UserAgent(
id=str(uuid.uuid4()),
user_id=user_id,
agent_name=agent_name,
agent_config=agent_config,
)
db.add(agent)
db.commit()
db.refresh(agent)
return agent

95
backend/db/database.py Normal file
View File

@@ -0,0 +1,95 @@
"""
数据库连接管理模块
使用 SQLAlchemy ORM支持同步操作
"""
import os
import yaml
from typing import Generator
from contextlib import contextmanager
import json
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import QueuePool
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
# 读取配置
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
with open(yaml_file, "r", encoding="utf-8") as file:
config = yaml.safe_load(file).get("database", {})
except Exception:
config = {}
def get_database_url() -> str:
"""获取数据库连接 URL"""
# 优先使用环境变量
host = os.getenv("DB_HOST", config.get("host", "localhost"))
port = os.getenv("DB_PORT", config.get("port", "5432"))
user = os.getenv("DB_USER", config.get("username", "postgres"))
password = os.getenv("DB_PASSWORD", config.get("password", ""))
dbname = os.getenv("DB_NAME", config.get("name", "agentcoord"))
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
# 创建引擎
DATABASE_URL = get_database_url()
engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=config.get("pool_size", 10),
max_overflow=config.get("max_overflow", 20),
pool_pre_ping=True,
echo=False,
# JSONB 类型处理器配置
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 基础类
Base = declarative_base()
def get_db() -> Generator:
"""
获取数据库会话
用法: for db in get_db(): ...
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def get_db_context() -> Generator:
"""
上下文管理器方式获取数据库会话
用法: with get_db_context() as db: ...
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception as e:
db.rollback()
raise
finally:
db.close()
def test_connection() -> bool:
"""测试数据库连接"""
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
except Exception as e:
return False

22
backend/db/init_db.py Normal file
View File

@@ -0,0 +1,22 @@
"""
数据库初始化脚本
运行此脚本创建所有表结构
基于 DATABASE_DESIGN.md 设计
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from db.database import engine, Base
from db.models import MultiAgentTask, UserAgent
def init_database():
"""初始化数据库表结构"""
Base.metadata.create_all(bind=engine)
if __name__ == "__main__":
init_database()

104
backend/db/models.py Normal file
View File

@@ -0,0 +1,104 @@
"""
SQLAlchemy ORM 数据模型
对应数据库表结构 (基于 DATABASE_DESIGN.md)
"""
import uuid
from datetime import datetime, timezone
from enum import Enum as PyEnum
from sqlalchemy import Column, String, Text, DateTime, Integer, Enum, Index, ForeignKey
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from .database import Base
class TaskStatus(str, PyEnum):
"""任务状态枚举"""
GENERATING = "generating" # 生成中 - TaskProcess 生成阶段
EXECUTING = "executing" # 执行中 - 任务执行阶段
STOPPED = "stopped" # 已停止 - 用户手动停止执行
COMPLETED = "completed" # 已完成 - 任务正常完成
def utc_now():
"""获取当前 UTC 时间"""
return datetime.now(timezone.utc)
class MultiAgentTask(Base):
"""多智能体任务记录模型"""
__tablename__ = "multi_agent_tasks"
task_id = Column(String(64), primary_key=True)
user_id = Column(String(64), nullable=False, index=True)
query = Column(Text, nullable=False)
agents_info = Column(JSONB, nullable=False)
task_outline = Column(JSONB)
assigned_agents = Column(JSONB)
agent_scores = Column(JSONB)
result = Column(JSONB)
status = Column(
Enum(TaskStatus, name="task_status_enum", create_type=False),
default=TaskStatus.GENERATING,
nullable=False
)
execution_count = Column(Integer, default=0, nullable=False)
generation_id = Column(String(64))
execution_id = Column(String(64))
rehearsal_log = Column(JSONB)
branches = Column(JSONB) # 任务大纲探索分支数据
created_at = Column(DateTime(timezone=True), default=utc_now)
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
__table_args__ = (
Index("idx_multi_agent_tasks_status", "status"),
Index("idx_multi_agent_tasks_generation_id", "generation_id"),
Index("idx_multi_agent_tasks_execution_id", "execution_id"),
)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"task_id": self.task_id,
"user_id": self.user_id,
"query": self.query,
"agents_info": self.agents_info,
"task_outline": self.task_outline,
"assigned_agents": self.assigned_agents,
"agent_scores": self.agent_scores,
"result": self.result,
"status": self.status.value if self.status else None,
"execution_count": self.execution_count,
"generation_id": self.generation_id,
"execution_id": self.execution_id,
"rehearsal_log": self.rehearsal_log,
"branches": self.branches,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class UserAgent(Base):
"""用户保存的智能体配置模型 (可选表)"""
__tablename__ = "user_agents"
id = Column(String(64), primary_key=True)
user_id = Column(String(64), nullable=False, index=True)
agent_name = Column(String(100), nullable=False)
agent_config = Column(JSONB, nullable=False)
created_at = Column(DateTime(timezone=True), default=utc_now)
__table_args__ = (
Index("idx_user_agents_user_created", "user_id", "created_at"),
)
def to_dict(self) -> dict:
"""转换为字典"""
return {
"id": self.id,
"user_id": self.user_id,
"agent_name": self.agent_name,
"agent_config": self.agent_config,
"created_at": self.created_at.isoformat() if self.created_at else None,
}

70
backend/db/schema.sql Normal file
View File

@@ -0,0 +1,70 @@
-- AgentCoord 数据库表结构
-- 基于 DATABASE_DESIGN.md 设计
-- 执行方式: psql -U postgres -d agentcoord -f schema.sql
-- =============================================================================
-- 表1: multi_agent_tasks (多智能体任务记录)
-- 状态枚举: pending/planning/generating/executing/completed/failed
-- =============================================================================
CREATE TABLE IF NOT EXISTS multi_agent_tasks (
task_id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
query TEXT NOT NULL,
agents_info JSONB NOT NULL,
task_outline JSONB,
assigned_agents JSONB,
agent_scores JSONB,
result JSONB,
status VARCHAR(20) DEFAULT 'pending',
execution_count INTEGER DEFAULT 0,
generation_id VARCHAR(64),
execution_id VARCHAR(64),
rehearsal_log JSONB,
branches JSONB, -- 任务大纲探索分支数据
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_user_id ON multi_agent_tasks(user_id);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_created_at ON multi_agent_tasks(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_status ON multi_agent_tasks(status);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_generation_id ON multi_agent_tasks(generation_id);
CREATE INDEX IF NOT EXISTS idx_multi_agent_tasks_execution_id ON multi_agent_tasks(execution_id);
-- =============================================================================
-- 表2: user_agents (用户保存的智能体配置) - 可选表
-- =============================================================================
CREATE TABLE IF NOT EXISTS user_agents (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
agent_name VARCHAR(100) NOT NULL,
agent_config JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_user_agents_user_id ON user_agents(user_id);
-- =============================================================================
-- 更新时间触发器函数
-- =============================================================================
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- 为 multi_agent_tasks 表创建触发器
CREATE TRIGGER update_multi_agent_tasks_updated_at
BEFORE UPDATE ON multi_agent_tasks
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
DO $$
BEGIN
RAISE NOTICE '✅ PostgreSQL 数据库表结构创建完成!';
RAISE NOTICE '表: multi_agent_tasks (多智能体任务记录)';
RAISE NOTICE '表: user_agents (用户智能体配置)';
END $$;