This commit is contained in:
madaoxs
2026-01-13 17:55:36 +08:00
commit e348e845f2
14 changed files with 700 additions and 0 deletions

View File

@@ -0,0 +1,93 @@
package db
import (
"database/sql"
"fmt"
"net"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
)
type PostgresConfig struct {
Address string
Username string
Password string
Database string
SSLMode string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime string
ConnMaxIdleTime string
}
func NewPostgres(cfg PostgresConfig) (*sql.DB, error) {
if cfg.Address == "" {
return nil, fmt.Errorf("database address is required")
}
if cfg.Username == "" {
return nil, fmt.Errorf("database username is required")
}
if cfg.Database == "" {
return nil, fmt.Errorf("database name is required")
}
host, port := parseAddress(cfg.Address)
if host == "" {
return nil, fmt.Errorf("database address is invalid")
}
connMaxLifetime, err := parseDuration(cfg.ConnMaxLifetime)
if err != nil {
return nil, fmt.Errorf("conn_max_lifetime: %w", err)
}
connMaxIdleTime, err := parseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return nil, fmt.Errorf("conn_max_idle_time: %w", err)
}
dsn := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
host, port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode,
)
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, fmt.Errorf("open postgres: %w", err)
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetConnMaxLifetime(connMaxLifetime)
db.SetConnMaxIdleTime(connMaxIdleTime)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping postgres: %w", err)
}
return db, nil
}
func parseAddress(address string) (string, string) {
address = strings.TrimSpace(address)
if address == "" {
return "", ""
}
if strings.Contains(address, ":") {
host, port, err := net.SplitHostPort(address)
if err == nil && host != "" && port != "" {
return host, port
}
}
return address, "5432"
}
func parseDuration(value string) (time.Duration, error) {
if strings.TrimSpace(value) == "" {
return 0, nil
}
return time.ParseDuration(value)
}