feat:1.数据库存储功能添加(初版)2.后端REST API版本代码清理

This commit is contained in:
liailing1026
2026-02-25 10:55:51 +08:00
parent f736cd104a
commit 2140cfaf92
35 changed files with 3912 additions and 2981 deletions

95
backend/db/database.py Normal file
View File

@@ -0,0 +1,95 @@
"""
数据库连接管理模块
使用 SQLAlchemy ORM支持同步操作
"""
import os
import yaml
from typing import Generator
from contextlib import contextmanager
import json
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import QueuePool
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
# 读取配置
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
with open(yaml_file, "r", encoding="utf-8") as file:
config = yaml.safe_load(file).get("database", {})
except Exception:
config = {}
def get_database_url() -> str:
"""获取数据库连接 URL"""
# 优先使用环境变量
host = os.getenv("DB_HOST", config.get("host", "localhost"))
port = os.getenv("DB_PORT", config.get("port", "5432"))
user = os.getenv("DB_USER", config.get("username", "postgres"))
password = os.getenv("DB_PASSWORD", config.get("password", ""))
dbname = os.getenv("DB_NAME", config.get("name", "agentcoord"))
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
# 创建引擎
DATABASE_URL = get_database_url()
engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=config.get("pool_size", 10),
max_overflow=config.get("max_overflow", 20),
pool_pre_ping=True,
echo=False,
# JSONB 类型处理器配置
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 基础类
Base = declarative_base()
def get_db() -> Generator:
"""
获取数据库会话
用法: for db in get_db(): ...
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def get_db_context() -> Generator:
"""
上下文管理器方式获取数据库会话
用法: with get_db_context() as db: ...
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception as e:
db.rollback()
raise
finally:
db.close()
def test_connection() -> bool:
"""测试数据库连接"""
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
except Exception as e:
return False