Files
AgentCoord/backend/db/database.py

96 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数据库连接管理模块
使用 SQLAlchemy ORM支持同步操作
"""
import json
import logging
import os
from contextlib import contextmanager
from typing import Generator, Optional
from urllib.parse import quote

import yaml
from sqlalchemy import create_engine, text
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import QueuePool
# Load configuration: the "database" section of config/config.yaml, resolved
# relative to the current working directory (not this file's location).
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")


def _load_database_config(path: str) -> dict:
    """Return the ``database`` section of the YAML file at *path* as a dict.

    The config file is optional. An empty dict is returned when the file is
    missing, cannot be read or decoded, or is not valid YAML. It is also
    returned when there is no mapping under ``database``: an empty file, a
    non-mapping top level, or a ``database:`` key with no value. The caller
    can therefore always use ``config.get(...)`` safely. Problems other than
    a missing file are logged as warnings instead of being silently ignored.
    """
    try:
        with open(path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
    except FileNotFoundError:
        # No config file: environment variables / defaults are used instead.
        return {}
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        logging.getLogger(__name__).warning(
            "Ignoring unreadable database config %s: %s", path, exc
        )
        return {}
    if not isinstance(data, dict):
        return {}
    section = data.get("database")
    # "database:" with no value loads as None; never hand None to callers.
    return section if isinstance(section, dict) else {}


config = _load_database_config(yaml_file)
def get_database_url(cfg: Optional[dict] = None) -> str:
    """Build the PostgreSQL connection URL from the environment and config.

    Each part is resolved in priority order. An environment variable wins,
    then the matching key of the ``database`` config section, then a
    built-in default:

        DB_HOST      host       "localhost"
        DB_PORT      port       "5432"
        DB_USER      username   "postgres"
        DB_PASSWORD  password   ""
        DB_NAME      name       "agentcoord"

    Args:
        cfg: Config mapping to read instead of the module-level ``config``.
            When omitted, the loaded config is used.

    Returns:
        A ``postgresql://user:password@host:port/dbname`` URL. The user name
        and password are percent-encoded, so characters such as ``@``, ``:``,
        ``/`` or ``%`` in credentials do not corrupt the URL. Supply them
        unencoded.
    """
    source = config if cfg is None else cfg

    def setting(env_name: str, key: str, default: str) -> str:
        # An exported env var wins even when empty, matching
        # os.getenv(name, default).
        env_value = os.getenv(env_name)
        if env_value is not None:
            return env_value
        # A config key written with no value loads as None. Treat it as unset
        # rather than rendering the literal text "None" into the URL.
        value = source.get(key)
        return default if value is None else str(value)

    host = setting("DB_HOST", "host", "localhost")
    port = setting("DB_PORT", "port", "5432")
    # safe="" also encodes "/" and "+"; SQLAlchemy decodes these with unquote.
    user = quote(setting("DB_USER", "username", "postgres"), safe="")
    password = quote(setting("DB_PASSWORD", "password", ""), safe="")
    dbname = setting("DB_NAME", "name", "agentcoord")
    return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
# JSON / JSONB (de)serialization hooks handed to the engine.
def _dump_json(obj):
    """Serialize a value to JSON text, keeping non-ASCII characters as-is."""
    return json.dumps(obj, ensure_ascii=False)


def _load_json(raw):
    """Parse JSON text; anything that is not a str is returned untouched."""
    if isinstance(raw, str):
        return json.loads(raw)
    return raw


# Create the engine
DATABASE_URL = get_database_url()
engine = create_engine(
    DATABASE_URL,
    echo=False,
    pool_pre_ping=True,
    # Connection pool sizing, overridable from the "database" config section.
    poolclass=QueuePool,
    pool_size=config.get("pool_size", 10),
    max_overflow=config.get("max_overflow", 20),
    json_serializer=_dump_json,
    json_deserializer=_load_json,
)
# Declarative base class for ORM models.
Base = declarative_base()

# Session factory bound to the engine, with autocommit and autoflush disabled.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db() -> Generator:
    """Yield a single database session, closing it when the generator ends.

    Usage: ``for db in get_db(): ...``

    The session is closed once the generator is resumed, closed, or
    receives an exception. No commit or rollback is issued here.
    """
    with SessionLocal() as session:
        yield session
@contextmanager
def get_db_context() -> Generator:
    """Provide a database session scoped to a ``with`` block.

    Usage: ``with get_db_context() as db: ...``

    If the block finishes normally, the session is committed. If the block
    raises an ``Exception``, the session is rolled back and the exception is
    re-raised; this includes a failure inside the commit itself. The session
    is always closed afterwards.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        # Undo partial work, then let the caller see the original error.
        db.rollback()
        raise
    finally:
        db.close()
def test_connection() -> bool:
    """Check database connectivity by executing ``SELECT 1``.

    Returns:
        True if a connection could be opened and the query executed.
        Otherwise False; the exception is logged as a warning instead of
        being raised, so callers always get a boolean.
    """
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as exc:
        # Best-effort probe: report why it failed instead of hiding it.
        logging.getLogger(__name__).warning(
            "Database connection test failed: %s", exc
        )
        return False
    return True