94 lines
2.0 KiB
Go
94 lines
2.0 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
)
|
|
|
|
// PostgresConfig carries everything NewPostgres needs to open and tune a
// PostgreSQL connection pool.
type PostgresConfig struct {
	// Address is the server location, written as "host" or "host:port".
	// Port 5432 is assumed when no port is given.
	Address string

	// Username and Database are required; Password is the secret that goes
	// with Username.
	Username, Password, Database string

	// SSLMode is passed through as the sslmode connection parameter.
	SSLMode string

	// MaxOpenConns and MaxIdleConns are handed to sql.DB.SetMaxOpenConns and
	// sql.DB.SetMaxIdleConns unchanged. Note that database/sql treats
	// MaxIdleConns <= 0 as "keep no idle connections".
	MaxOpenConns, MaxIdleConns int

	// ConnMaxLifetime and ConnMaxIdleTime are duration strings in
	// time.ParseDuration syntax (for example "30m"). An empty string is
	// treated as a zero duration.
	ConnMaxLifetime, ConnMaxIdleTime string
}
|
|
|
|
func NewPostgres(cfg PostgresConfig) (*sql.DB, error) {
|
|
if cfg.Address == "" {
|
|
return nil, fmt.Errorf("database address is required")
|
|
}
|
|
if cfg.Username == "" {
|
|
return nil, fmt.Errorf("database username is required")
|
|
}
|
|
if cfg.Database == "" {
|
|
return nil, fmt.Errorf("database name is required")
|
|
}
|
|
|
|
host, port := parseAddress(cfg.Address)
|
|
if host == "" {
|
|
return nil, fmt.Errorf("database address is invalid")
|
|
}
|
|
|
|
connMaxLifetime, err := parseDuration(cfg.ConnMaxLifetime)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("conn_max_lifetime: %w", err)
|
|
}
|
|
connMaxIdleTime, err := parseDuration(cfg.ConnMaxIdleTime)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("conn_max_idle_time: %w", err)
|
|
}
|
|
|
|
dsn := fmt.Sprintf(
|
|
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
|
|
host, port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode,
|
|
)
|
|
|
|
db, err := sql.Open("pgx", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open postgres: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
db.SetConnMaxLifetime(connMaxLifetime)
|
|
db.SetConnMaxIdleTime(connMaxIdleTime)
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("ping postgres: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// parseAddress splits a server address into host and port.
//
// Accepted forms, after surrounding whitespace is trimmed:
//
//	"host:port"   -> ("host", "port")
//	"host"        -> ("host", "5432")
//	"host:"       -> ("host", "5432")   empty port falls back to the default
//	"[::1]:5433"  -> ("::1", "5433")
//	"[::1]"       -> ("::1", "5432")    brackets are stripped
//	"::1"         -> ("::1", "5432")    bare IPv6 literal, no port possible
//
// An empty address, or one with an empty host such as ":5432" or "[]",
// returns ("", "") so the caller can reject it.
func parseAddress(address string) (string, string) {
	const defaultPort = "5432"

	address = strings.TrimSpace(address)
	if address == "" {
		return "", ""
	}

	if host, port, err := net.SplitHostPort(address); err == nil {
		if host == "" {
			return "", ""
		}
		if port == "" {
			port = defaultPort
		}
		return host, port
	}

	// SplitHostPort rejects a bracketed IPv6 literal that has no port;
	// unwrap it instead of passing the brackets on as part of the host.
	if strings.HasPrefix(address, "[") && strings.HasSuffix(address, "]") {
		host := address[1 : len(address)-1]
		if host == "" {
			return "", ""
		}
		return host, defaultPort
	}

	// Plain host name, bare IPv6 literal, or anything else without a
	// separable port: use it verbatim with the default port.
	return address, defaultPort
}
|
|
|
|
// parseDuration converts a duration string such as "30m" or "1h30m" into a
// time.Duration using time.ParseDuration.
//
// Surrounding whitespace is ignored. An empty or whitespace-only value
// returns (0, nil) rather than an error, so an unset setting maps to a zero
// duration. Any other unparsable value returns the time.ParseDuration error.
func parseDuration(value string) (time.Duration, error) {
	// Trim once and parse the trimmed text, so " 5m " is accepted instead of
	// passing the emptiness check and then failing in ParseDuration.
	value = strings.TrimSpace(value)
	if value == "" {
		return 0, nil
	}
	return time.ParseDuration(value)
}
|