96 lines
2.4 KiB
Python
96 lines
2.4 KiB
Python
"""
|
||
数据库连接管理模块
|
||
使用 SQLAlchemy ORM,支持同步操作
|
||
"""
|
||
|
||
import json
import logging
import os
from contextlib import contextmanager
from typing import Generator, Optional
from urllib.parse import quote

import yaml
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import QueuePool
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
|
||
|
||
# Read configuration: the optional ``database`` section of
# config/config.yaml, resolved relative to the current working directory.
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")


def _load_database_config(path: str) -> dict:
    """Return the ``database`` mapping from the YAML file at *path*.

    Loading is best-effort. Each of the following yields ``{}``, so that
    environment variables and built-in defaults apply:

    - a missing, unreadable or malformed file;
    - a top level that is not a mapping;
    - an empty or non-mapping ``database`` section.

    Read/parse failures other than a missing file are logged as warnings
    instead of being silently discarded.
    """
    try:
        with open(path, "r", encoding="utf-8") as fh:
            data = yaml.safe_load(fh)
    except FileNotFoundError:
        # No config file is acceptable: values can come from env vars.
        return {}
    except (OSError, ValueError, yaml.YAMLError) as exc:
        # ValueError covers UnicodeDecodeError from a non-UTF-8 file.
        logging.getLogger(__name__).warning(
            "Could not load database config from %s: %s", path, exc
        )
        return {}
    # safe_load returns None for an empty file, and a bare ``database:`` key
    # maps to None. Both must become {} so later config.get() calls work.
    if not isinstance(data, dict):
        return {}
    section = data.get("database")
    return section if isinstance(section, dict) else {}


config = _load_database_config(yaml_file)
|
||
|
||
|
||
def get_database_url(cfg: Optional[dict] = None) -> str:
    """Build the PostgreSQL connection URL.

    Each component comes from an environment variable when set. Otherwise it
    comes from the config mapping, and failing that from a built-in default:

    - host:     ``DB_HOST``     / ``host``     / ``"localhost"``
    - port:     ``DB_PORT``     / ``port``     / ``"5432"``
    - user:     ``DB_USER``     / ``username`` / ``"postgres"``
    - password: ``DB_PASSWORD`` / ``password`` / ``""``
    - database: ``DB_NAME``     / ``name``     / ``"agentcoord"``

    Args:
        cfg: Mapping to read fallback values from. Defaults to the
            module-level ``config`` loaded from config.yaml.

    Returns:
        A URL of the form ``postgresql://user:password@host:port/dbname``.
        User and password are percent-encoded, so characters such as
        ``@``, ``:``, ``/``, ``#`` or ``%`` cannot corrupt the URL.
    """
    if cfg is None:
        cfg = config

    # Environment variables take precedence over config values.
    host = os.getenv("DB_HOST", cfg.get("host", "localhost"))
    port = os.getenv("DB_PORT", cfg.get("port", "5432"))
    user = os.getenv("DB_USER", cfg.get("username", "postgres"))
    password = os.getenv("DB_PASSWORD", cfg.get("password", ""))
    dbname = os.getenv("DB_NAME", cfg.get("name", "agentcoord"))

    # safe="" also escapes "/" (which quote() keeps by default).
    # str() tolerates non-string YAML scalars such as a numeric password.
    user_q = quote(str(user), safe="")
    password_q = quote(str(password), safe="")
    return f"postgresql://{user_q}:{password_q}@{host}:{port}/{dbname}"
|
||
|
||
|
||
# Create the engine. Pool sizing comes from the database config section
# when present; otherwise the defaults below apply.
DATABASE_URL = get_database_url()


def _dump_json(obj):
    """Serialize *obj* for JSON/JSONB columns without escaping non-ASCII text."""
    return json.dumps(obj, ensure_ascii=False)


def _load_json(raw):
    """Parse JSON text. Non-string values are returned unchanged.

    Non-string inputs are presumably values the DB driver already decoded.
    """
    if isinstance(raw, str):
        return json.loads(raw)
    return raw


_engine_options = {
    "poolclass": QueuePool,
    "pool_size": config.get("pool_size", 10),
    "max_overflow": config.get("max_overflow", 20),
    # Check each pooled connection for liveness when it is checked out.
    "pool_pre_ping": True,
    "echo": False,
    # JSON / JSONB type handlers.
    "json_serializer": _dump_json,
    "json_deserializer": _load_json,
}

engine = create_engine(DATABASE_URL, **_engine_options)
|
||
|
||
# Session factory bound to ``engine``; autoflush and autocommit are disabled.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# Declarative base class for ORM model definitions.
Base = declarative_base()
|
||
|
||
|
||
def get_db() -> Generator:
    """Yield a single database session, closing it afterwards.

    Usage: ``for db in get_db(): ...``

    No commit or rollback is performed here. The session is closed once the
    generator is exhausted or closed.
    """
    # Session.__exit__ calls close(), the same cleanup as a try/finally.
    with SessionLocal() as session:
        yield session
|
||
|
||
|
||
@contextmanager
def get_db_context() -> Generator:
    """Provide a database session as a transactional context manager.

    Usage: ``with get_db_context() as db: ...``

    - If the ``with`` body completes normally, the session is committed.
    - If the body or the commit raises an ``Exception``, the session is
      rolled back and the exception is re-raised unchanged.
    - The session is always closed on exit.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        # Undo partial work, then propagate the original error to the caller.
        db.rollback()
        raise
    finally:
        db.close()
|
||
|
||
|
||
def test_connection() -> bool:
    """Check whether the database is reachable.

    Opens a connection from ``engine`` and executes ``SELECT 1``.

    Returns:
        True if the query succeeds. False if any exception is raised; the
        failure is logged as a warning rather than silently discarded.
    """
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as exc:  # health-check boundary: report, don't raise
        logging.getLogger(__name__).warning(
            "Database connection test failed: %s", exc
        )
        return False
    return True
|