init
This commit is contained in:
93
xiangj-adapter/internal/db/postgres.go
Normal file
93
xiangj-adapter/internal/db/postgres.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
type PostgresConfig struct {
|
||||
Address string
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
SSLMode string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime string
|
||||
ConnMaxIdleTime string
|
||||
}
|
||||
|
||||
func NewPostgres(cfg PostgresConfig) (*sql.DB, error) {
|
||||
if cfg.Address == "" {
|
||||
return nil, fmt.Errorf("database address is required")
|
||||
}
|
||||
if cfg.Username == "" {
|
||||
return nil, fmt.Errorf("database username is required")
|
||||
}
|
||||
if cfg.Database == "" {
|
||||
return nil, fmt.Errorf("database name is required")
|
||||
}
|
||||
|
||||
host, port := parseAddress(cfg.Address)
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("database address is invalid")
|
||||
}
|
||||
|
||||
connMaxLifetime, err := parseDuration(cfg.ConnMaxLifetime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conn_max_lifetime: %w", err)
|
||||
}
|
||||
connMaxIdleTime, err := parseDuration(cfg.ConnMaxIdleTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("conn_max_idle_time: %w", err)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
|
||||
host, port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode,
|
||||
)
|
||||
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func parseAddress(address string) (string, string) {
|
||||
address = strings.TrimSpace(address)
|
||||
if address == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if strings.Contains(address, ":") {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err == nil && host != "" && port != "" {
|
||||
return host, port
|
||||
}
|
||||
}
|
||||
|
||||
return address, "5432"
|
||||
}
|
||||
|
||||
func parseDuration(value string) (time.Duration, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return 0, nil
|
||||
}
|
||||
return time.ParseDuration(value)
|
||||
}
|
||||
Reference in New Issue
Block a user