This commit is contained in:
madaoxs
2026-01-13 17:55:36 +08:00
commit e348e845f2
14 changed files with 700 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
package apperror
import "github.com/gofiber/fiber/v2"
type AppError struct {
Status int
Message string
}
func (e *AppError) Error() string {
return e.Message
}
func BadRequest(message string) *AppError {
return &AppError{Status: fiber.StatusBadRequest, Message: message}
}
func Internal(message string) *AppError {
return &AppError{Status: fiber.StatusInternalServerError, Message: message}
}

View File

@@ -0,0 +1,88 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Log LogConfig `yaml:"log"`
}
type ServerConfig struct {
Address string `yaml:"address"`
}
type DatabaseConfig struct {
Address string `yaml:"address"`
Username string `yaml:"username"`
Password string `yaml:"password"`
Database string `yaml:"database"`
Schema string `yaml:"schema"`
SSLMode string `yaml:"sslmode"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
}
type LogConfig struct {
Level string `yaml:"level"`
Output string `yaml:"output"`
FilePath string `yaml:"file_path"`
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
applyDefaults(&cfg)
return &cfg, nil
}
func applyDefaults(cfg *Config) {
if cfg.Server.Address == "" {
cfg.Server.Address = ":21056"
}
if cfg.Log.Level == "" {
cfg.Log.Level = "info"
}
if cfg.Log.Output == "" {
cfg.Log.Output = "console"
}
if cfg.Log.FilePath == "" {
cfg.Log.FilePath = "/app/log/app.log"
}
if cfg.Database.MaxOpenConns == 0 {
cfg.Database.MaxOpenConns = 10
}
if cfg.Database.MaxIdleConns == 0 {
cfg.Database.MaxIdleConns = 5
}
if cfg.Database.Address == "" {
cfg.Database.Address = "localhost:5432"
}
if cfg.Database.Schema == "" {
cfg.Database.Schema = "public"
}
if cfg.Database.SSLMode == "" {
cfg.Database.SSLMode = "disable"
}
if cfg.Database.ConnMaxLifetime == "" {
cfg.Database.ConnMaxLifetime = "30m"
}
if cfg.Database.ConnMaxIdleTime == "" {
cfg.Database.ConnMaxIdleTime = "5m"
}
}

View File

@@ -0,0 +1,106 @@
package controller
import (
"log/slog"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"xiangj-adapter/internal/apperror"
"xiangj-adapter/internal/model"
"xiangj-adapter/internal/service"
)
type DataController struct {
service *service.DataService
log *slog.Logger
}
func NewDataController(service *service.DataService, log *slog.Logger) *DataController {
return &DataController{service: service, log: log}
}
func (c *DataController) GetData(ctx *fiber.Ctx) error {
op := strings.ToLower(ctx.Query("op"))
if op == "" {
op = "select"
}
switch op {
case "select":
return c.handleSelect(ctx)
default:
return apperror.BadRequest("op not supported")
}
}
func (c *DataController) handleSelect(ctx *fiber.Ctx) error {
doid := ctx.Query("doid")
table := strings.ToLower(lastSegment(doid))
if table == "" {
return apperror.BadRequest("doid is required")
}
offset, err := parseNonNegativeIntDefault(ctx.Query("offset"), "offset", 0)
if err != nil {
return err
}
count, err := parsePositiveIntDefault(ctx.Query("count"), "count", 100)
if err != nil {
return err
}
data, err := c.service.Select(ctx.Context(), table, offset, count)
if err != nil {
c.log.Error("select failed", "table", table, "err", err)
return apperror.Internal("query failed")
}
return ctx.JSON(model.APIResponse{Data: data, Error: ""})
}
func lastSegment(input string) string {
trimmed := strings.Trim(input, "/")
if trimmed == "" {
return ""
}
parts := strings.Split(trimmed, "/")
return parts[len(parts)-1]
}
func parseNonNegativeInt(value, field string) (int, error) {
if value == "" {
return 0, apperror.BadRequest(field + " is required")
}
parsed, err := strconv.Atoi(value)
if err != nil || parsed < 0 {
return 0, apperror.BadRequest(field + " must be a non-negative integer")
}
return parsed, nil
}
func parsePositiveInt(value, field string) (int, error) {
if value == "" {
return 0, apperror.BadRequest(field + " is required")
}
parsed, err := strconv.Atoi(value)
if err != nil || parsed <= 0 {
return 0, apperror.BadRequest(field + " must be a positive integer")
}
return parsed, nil
}
func parseNonNegativeIntDefault(value, field string, defaultValue int) (int, error) {
if value == "" {
return defaultValue, nil
}
return parseNonNegativeInt(value, field)
}
func parsePositiveIntDefault(value, field string, defaultValue int) (int, error) {
if value == "" {
return defaultValue, nil
}
return parsePositiveInt(value, field)
}

View File

@@ -0,0 +1,93 @@
package db
import (
"database/sql"
"fmt"
"net"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
)
type PostgresConfig struct {
Address string
Username string
Password string
Database string
SSLMode string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime string
ConnMaxIdleTime string
}
func NewPostgres(cfg PostgresConfig) (*sql.DB, error) {
if cfg.Address == "" {
return nil, fmt.Errorf("database address is required")
}
if cfg.Username == "" {
return nil, fmt.Errorf("database username is required")
}
if cfg.Database == "" {
return nil, fmt.Errorf("database name is required")
}
host, port := parseAddress(cfg.Address)
if host == "" {
return nil, fmt.Errorf("database address is invalid")
}
connMaxLifetime, err := parseDuration(cfg.ConnMaxLifetime)
if err != nil {
return nil, fmt.Errorf("conn_max_lifetime: %w", err)
}
connMaxIdleTime, err := parseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return nil, fmt.Errorf("conn_max_idle_time: %w", err)
}
dsn := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
host, port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode,
)
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, fmt.Errorf("open postgres: %w", err)
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetConnMaxLifetime(connMaxLifetime)
db.SetConnMaxIdleTime(connMaxIdleTime)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping postgres: %w", err)
}
return db, nil
}
func parseAddress(address string) (string, string) {
address = strings.TrimSpace(address)
if address == "" {
return "", ""
}
if strings.Contains(address, ":") {
host, port, err := net.SplitHostPort(address)
if err == nil && host != "" && port != "" {
return host, port
}
}
return address, "5432"
}
func parseDuration(value string) (time.Duration, error) {
if strings.TrimSpace(value) == "" {
return 0, nil
}
return time.ParseDuration(value)
}

View File

@@ -0,0 +1,44 @@
package logger
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
)
type Options struct {
Level string
Output string
FilePath string
}
func New(opts Options) (*slog.Logger, error) {
var lvl slog.Level
if err := lvl.UnmarshalText([]byte(opts.Level)); err != nil {
lvl = slog.LevelInfo
}
output := strings.ToLower(strings.TrimSpace(opts.Output))
switch output {
case "", "console":
handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})
return slog.New(handler), nil
case "file":
if strings.TrimSpace(opts.FilePath) == "" {
return nil, fmt.Errorf("log file path is required")
}
if err := os.MkdirAll(filepath.Dir(opts.FilePath), 0o755); err != nil {
return nil, fmt.Errorf("create log directory: %w", err)
}
file, err := os.OpenFile(opts.FilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
handler := slog.NewTextHandler(file, &slog.HandlerOptions{Level: lvl})
return slog.New(handler), nil
default:
return nil, fmt.Errorf("unsupported log output: %s", opts.Output)
}
}

View File

@@ -0,0 +1,6 @@
package model
type APIResponse struct {
Data any `json:"data"`
Error string `json:"error"`
}

View File

@@ -0,0 +1,11 @@
package router
import (
"github.com/gofiber/fiber/v2"
"xiangj-adapter/internal/controller"
)
func Register(app *fiber.App, dataController *controller.DataController) {
app.Get("/data", dataController.GetData)
}

View File

@@ -0,0 +1,79 @@
package service
import (
"context"
"database/sql"
"fmt"
"regexp"
)
var tableNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
type DataService struct {
db *sql.DB
schema string
}
func NewDataService(db *sql.DB, schema string) *DataService {
return &DataService{db: db, schema: schema}
}
func (s *DataService) Select(ctx context.Context, table string, offset, count int) ([]map[string]any, error) {
if !tableNameRegex.MatchString(table) {
return nil, fmt.Errorf("invalid table name")
}
if s.schema != "" && !tableNameRegex.MatchString(s.schema) {
return nil, fmt.Errorf("invalid schema name")
}
qualifiedTable := table
if s.schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", s.schema, table)
}
query := fmt.Sprintf("SELECT * FROM %s LIMIT $2 OFFSET $1", qualifiedTable)
rows, err := s.db.QueryContext(ctx, query, offset, count)
if err != nil {
return nil, fmt.Errorf("query table %s: %w", table, err)
}
defer rows.Close()
return scanRows(rows)
}
func scanRows(rows *sql.Rows) ([]map[string]any, error) {
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("read columns: %w", err)
}
results := make([]map[string]any, 0)
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
row := make(map[string]any, len(columns))
for i, col := range columns {
val := values[i]
if b, ok := val.([]byte); ok {
row[col] = string(b)
continue
}
row[col] = val
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate rows: %w", err)
}
return results, nil
}