""" 数据库连接管理模块 使用 SQLAlchemy ORM,支持同步操作 """ import os import yaml from typing import Generator from contextlib import contextmanager import json from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.pool import QueuePool from sqlalchemy.dialects.postgresql import dialect as pg_dialect # 读取配置 yaml_file = os.path.join(os.getcwd(), "config", "config.yaml") try: with open(yaml_file, "r", encoding="utf-8") as file: config = yaml.safe_load(file).get("database", {}) except Exception: config = {} def get_database_url() -> str: """获取数据库连接 URL""" # 优先使用环境变量 host = os.getenv("DB_HOST", config.get("host", "localhost")) port = os.getenv("DB_PORT", config.get("port", "5432")) user = os.getenv("DB_USER", config.get("username", "postgres")) password = os.getenv("DB_PASSWORD", config.get("password", "")) dbname = os.getenv("DB_NAME", config.get("name", "agentcoord")) return f"postgresql://{user}:{password}@{host}:{port}/{dbname}" # 创建引擎 DATABASE_URL = get_database_url() engine = create_engine( DATABASE_URL, poolclass=QueuePool, pool_size=config.get("pool_size", 10), max_overflow=config.get("max_overflow", 20), pool_pre_ping=True, echo=False, # JSONB 类型处理器配置 json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), json_deserializer=lambda s: json.loads(s) if isinstance(s, str) else s ) # 创建会话工厂 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # 基础类 Base = declarative_base() def get_db() -> Generator: """ 获取数据库会话 用法: for db in get_db(): ... """ db = SessionLocal() try: yield db finally: db.close() @contextmanager def get_db_context() -> Generator: """ 上下文管理器方式获取数据库会话 用法: with get_db_context() as db: ... """ db = SessionLocal() try: yield db db.commit() except Exception as e: db.rollback() raise finally: db.close() def test_connection() -> bool: """测试数据库连接""" try: with engine.connect() as conn: conn.execute(text("SELECT 1")) return True except Exception as e: return False