feat:1.数据库存储功能添加(初版)2.后端REST API版本代码清理
This commit is contained in:
95
backend/db/database.py
Normal file
95
backend/db/database.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
数据库连接管理模块
|
||||
使用 SQLAlchemy ORM,支持同步操作
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from typing import Generator
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
|
||||
|
||||
# 读取配置
|
||||
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
|
||||
try:
|
||||
with open(yaml_file, "r", encoding="utf-8") as file:
|
||||
config = yaml.safe_load(file).get("database", {})
|
||||
except Exception:
|
||||
config = {}
|
||||
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""获取数据库连接 URL"""
|
||||
# 优先使用环境变量
|
||||
host = os.getenv("DB_HOST", config.get("host", "localhost"))
|
||||
port = os.getenv("DB_PORT", config.get("port", "5432"))
|
||||
user = os.getenv("DB_USER", config.get("username", "postgres"))
|
||||
password = os.getenv("DB_PASSWORD", config.get("password", ""))
|
||||
dbname = os.getenv("DB_NAME", config.get("name", "agentcoord"))
|
||||
|
||||
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
|
||||
|
||||
|
||||
# 创建引擎
|
||||
DATABASE_URL = get_database_url()
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
poolclass=QueuePool,
|
||||
pool_size=config.get("pool_size", 10),
|
||||
max_overflow=config.get("max_overflow", 20),
|
||||
pool_pre_ping=True,
|
||||
echo=False,
|
||||
# JSONB 类型处理器配置
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 基础类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
"""
|
||||
获取数据库会话
|
||||
用法: for db in get_db(): ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_context() -> Generator:
|
||||
"""
|
||||
上下文管理器方式获取数据库会话
|
||||
用法: with get_db_context() as db: ...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_connection() -> bool:
|
||||
"""测试数据库连接"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
Reference in New Issue
Block a user