package persistence

import (
	"context"
	"database/sql"
	"fmt"
	"testing"
	"time"

	_ "github.com/lib/pq"
	"go.yandata.net/iod/iod/go-trustlog/api/logger"
)

// Connection parameters for the local PostgreSQL instance used by the
// integration tests in this file.
const (
	postgresHost     = "localhost"
	postgresPort     = 5432
	postgresUser     = "postgres"
	postgresPassword = "postgres"
	postgresDatabase = "trustlog"
)

// setupPostgresDB opens a connection to the test PostgreSQL database and
// recreates the operation, trustlog_cursor and trustlog_retry tables from the
// "postgres" dialect DDL returned by GetDialectDDL.
//
// It returns (nil, false) when the connection cannot be opened or the server
// does not answer a ping within 2 seconds, so callers can skip rather than
// fail. Once the server is reachable, any schema-setup error is fatal.
//
// WARNING: existing tables with these names are dropped with CASCADE; only
// point the constants above at a disposable database.
func setupPostgresDB(t *testing.T) (*sql.DB, bool) {
	t.Helper()

	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		postgresHost, postgresPort, postgresUser, postgresPassword, postgresDatabase)

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		t.Logf("Failed to connect to PostgreSQL: %v (skipping)", err)
		return nil, false
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	if err := db.PingContext(ctx); err != nil {
		t.Logf("PostgreSQL not available: %v (skipping)", err)
		db.Close()
		return nil, false
	}

	// t.Fatalf exits via runtime.Goexit, which still runs deferred calls, so
	// this releases the connection pool if any schema step below aborts.
	ready := false
	defer func() {
		if !ready {
			db.Close()
		}
	}()

	// Initialise the schema.
	opDDL, cursorDDL, retryDDL, err := GetDialectDDL("postgres")
	if err != nil {
		t.Fatalf("Failed to get DDL: %v", err)
	}

	// Start every test from fresh tables. A failed DROP is not treated as
	// fatal here, but it is logged so a stale schema is diagnosable.
	for _, table := range []string{"operation", "trustlog_cursor", "trustlog_retry"} {
		if _, err := db.Exec("DROP TABLE IF EXISTS " + table + " CASCADE"); err != nil {
			t.Logf("Failed to drop table %s: %v", table, err)
		}
	}

	// Create the tables.
	if _, err := db.Exec(opDDL); err != nil {
		t.Fatalf("Failed to create operation table: %v", err)
	}
	if _, err := db.Exec(cursorDDL); err != nil {
		t.Fatalf("Failed to create cursor table: %v", err)
	}
	if _, err := db.Exec(retryDDL); err != nil {
		t.Fatalf("Failed to create retry table: %v", err)
	}

	ready = true
	return db, true
}

// TestPostgreSQL_Basic saves an operation, reads it back (including the
// optional ClientIP/ServerIP fields), then updates and re-checks its status.
func TestPostgreSQL_Basic(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping PostgreSQL integration test in short mode")
	}

	db, ok := setupPostgresDB(t)
	if !ok {
		t.Skip("PostgreSQL not available")
	}
	defer db.Close()

	ctx := context.Background()
	log := logger.GetGlobalLogger()

	// Create the repository under test.
	repo := NewOperationRepository(db, log)

	// Build a test operation with both optional IP fields populated.
	op := createTestOperation(t, fmt.Sprintf("pg-test-%d", time.Now().Unix()))
	clientIP := "192.168.1.100"
	serverIP := "10.0.0.1"
	op.ClientIP = &clientIP
	op.ServerIP = &serverIP

	// Save the operation.
	err := repo.Save(ctx, op, StatusNotTrustlogged)
	if err != nil {
		t.Fatalf("Failed to save operation: %v", err)
	}
	t.Logf("✅ Saved operation: %s", op.OpID)

	// Verify the saved row.
	savedOp, status, err := repo.FindByID(ctx, op.OpID)
	if err != nil {
		t.Fatalf("Failed to find operation: %v", err)
	}
	if savedOp.OpID != op.OpID {
		t.Errorf("Expected OpID %s, got %s", op.OpID, savedOp.OpID)
	}
	if status != StatusNotTrustlogged {
		t.Errorf("Expected status NOT_TRUSTLOGGED, got %v", status)
	}
	if savedOp.ClientIP == nil || *savedOp.ClientIP != clientIP {
		t.Error("ClientIP not saved correctly")
	}
	if savedOp.ServerIP == nil || *savedOp.ServerIP != serverIP {
		t.Error("ServerIP not saved correctly")
	}
	t.Logf("✅ Verified operation in PostgreSQL")

	// Update the status.
	err = repo.UpdateStatus(ctx, op.OpID, StatusTrustlogged)
	if err != nil {
		t.Fatalf("Failed to update status: %v", err)
	}

	// Verify the update.
	_, status, err = repo.FindByID(ctx, op.OpID)
	if err != nil {
		t.Fatalf("Failed to find operation after update: %v", err)
	}
	if status != StatusTrustlogged {
		t.Errorf("Expected status TRUSTLOGGED, got %v", status)
	}

	t.Logf("✅ PostgreSQL integration test passed")
}

// TestPostgreSQL_Transaction checks that SaveTx writes become visible after
// Commit and are discarded after Rollback.
func TestPostgreSQL_Transaction(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping PostgreSQL integration test in short mode")
	}

	db, ok := setupPostgresDB(t)
	if !ok {
		t.Skip("PostgreSQL not available")
	}
	defer db.Close()

	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewOperationRepository(db, log)

	// Commit path.
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatalf("Failed to begin transaction: %v", err)
	}

	op1 := createTestOperation(t, fmt.Sprintf("pg-tx-commit-%d", time.Now().Unix()))
	err = repo.SaveTx(ctx, tx, op1, StatusNotTrustlogged)
	if err != nil {
		tx.Rollback()
		t.Fatalf("Failed to save in transaction: %v", err)
	}

	err = tx.Commit()
	if err != nil {
		t.Fatalf("Failed to commit transaction: %v", err)
	}

	// The committed row must be visible.
	_, _, err = repo.FindByID(ctx, op1.OpID)
	if err != nil {
		t.Errorf("Operation should exist after commit: %v", err)
	}
	t.Logf("✅ Transaction commit tested")

	// Rollback path.
	tx, err = db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatalf("Failed to begin transaction: %v", err)
	}

	op2 := createTestOperation(t, fmt.Sprintf("pg-tx-rollback-%d", time.Now().Unix()))
	err = repo.SaveTx(ctx, tx, op2, StatusNotTrustlogged)
	if err != nil {
		tx.Rollback()
		t.Fatalf("Failed to save in transaction: %v", err)
	}

	err = tx.Rollback()
	if err != nil {
		t.Fatalf("Failed to rollback transaction: %v", err)
	}

	// The rolled-back row must be gone. FindByID failing is expected, but any
	// error (e.g. a broken connection) would satisfy that check, so also count
	// the rows directly to make the assertion meaningful.
	_, _, err = repo.FindByID(ctx, op2.OpID)
	if err == nil {
		t.Error("Operation should not exist after rollback")
	}

	var rolledBack int
	err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM operation WHERE op_id = $1", op2.OpID).Scan(&rolledBack)
	if err != nil {
		t.Fatalf("Failed to query database: %v", err)
	}
	if rolledBack != 0 {
		t.Errorf("Expected 0 records after rollback, got %d", rolledBack)
	}
	t.Logf("✅ Transaction rollback tested")

	t.Logf("✅ PostgreSQL transaction test passed")
}

// TestPostgreSQL_CursorOperations exercises InitCursor, GetCursor and
// UpdateCursor on a single cursor key.
func TestPostgreSQL_CursorOperations(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping PostgreSQL integration test in short mode")
	}

	db, ok := setupPostgresDB(t)
	if !ok {
		t.Skip("PostgreSQL not available")
	}
	defer db.Close()

	ctx := context.Background()
	log := logger.GetGlobalLogger()
	cursorRepo := NewCursorRepository(db, log)

	cursorKey := "pg-test-cursor"
	initialValue := time.Now().Format(time.RFC3339Nano)

	// Initialise the cursor.
	err := cursorRepo.InitCursor(ctx, cursorKey, initialValue)
	if err != nil {
		t.Fatalf("Failed to init cursor: %v", err)
	}

	// Read it back.
	value, err := cursorRepo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("Failed to get cursor: %v", err)
	}
	if value != initialValue {
		t.Errorf("Expected cursor value %s, got %s", initialValue, value)
	}

	// Move the cursor forward.
	newValue := time.Now().Add(1 * time.Hour).Format(time.RFC3339Nano)
	err = cursorRepo.UpdateCursor(ctx, cursorKey, newValue)
	if err != nil {
		t.Fatalf("Failed to update cursor: %v", err)
	}

	// Verify the update.
	value, err = cursorRepo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("Failed to get cursor after update: %v", err)
	}
	if value != newValue {
		t.Errorf("Expected cursor value %s, got %s", newValue, value)
	}

	t.Logf("✅ PostgreSQL cursor operations test passed")
}

// TestPostgreSQL_RetryOperations walks a retry record through add, pending
// lookup, increment, dead-lettering and deletion.
func TestPostgreSQL_RetryOperations(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping PostgreSQL integration test in short mode")
	}

	db, ok := setupPostgresDB(t)
	if !ok {
		t.Skip("PostgreSQL not available")
	}
	defer db.Close()

	ctx := context.Background()
	log := logger.GetGlobalLogger()
	retryRepo := NewRetryRepository(db, log)

	opID := fmt.Sprintf("pg-retry-%d", time.Now().Unix())

	// Add a retry record whose next-retry time is already in the past, so it
	// is immediately eligible.
	nextRetry := time.Now().Add(-1 * time.Second)
	err := retryRepo.AddRetry(ctx, opID, "test error", nextRetry)
	if err != nil {
		t.Fatalf("Failed to add retry: %v", err)
	}

	// The record should appear among the pending retries.
	records, err := retryRepo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("Failed to find pending retries: %v", err)
	}

	found := false
	for _, record := range records {
		if record.OpID == opID {
			found = true
			if record.RetryStatus != RetryStatusPending {
				t.Errorf("Expected status PENDING, got %v", record.RetryStatus)
			}
			break
		}
	}
	if !found {
		t.Error("Retry record not found")
	}

	// Bump the retry count.
	nextRetry2 := time.Now().Add(-1 * time.Second)
	err = retryRepo.IncrementRetry(ctx, opID, "retry error", nextRetry2)
	if err != nil {
		t.Fatalf("Failed to increment retry: %v", err)
	}

	// Dead-letter the record.
	err = retryRepo.MarkAsDeadLetter(ctx, opID, "max retries exceeded")
	if err != nil {
		t.Fatalf("Failed to mark as dead letter: %v", err)
	}

	// A dead-lettered record must no longer be listed as pending.
	records, err = retryRepo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("Failed to find pending retries: %v", err)
	}
	for _, record := range records {
		if record.OpID == opID {
			t.Error("Dead letter record should not be in pending list")
		}
	}

	// Remove the retry record.
	err = retryRepo.DeleteRetry(ctx, opID)
	if err != nil {
		t.Fatalf("Failed to delete retry: %v", err)
	}

	t.Logf("✅ PostgreSQL retry operations test passed")
}

// TestPostgreSQL_PersistenceManager checks that a PersistenceManager using
// the StrategyDBOnly strategy writes exactly one operation row.
func TestPostgreSQL_PersistenceManager(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping PostgreSQL integration test in short mode")
	}

	db, ok := setupPostgresDB(t)
	if !ok {
		t.Skip("PostgreSQL not available")
	}
	defer db.Close()

	ctx := context.Background()
	log := logger.GetGlobalLogger()

	// DB-only strategy.
	config := PersistenceConfig{
		Strategy: StrategyDBOnly,
	}

	manager := NewPersistenceManager(db, config, log)

	op := createTestOperation(t, fmt.Sprintf("pg-manager-%d", time.Now().Unix()))

	err := manager.SaveOperation(ctx, op)
	if err != nil {
		t.Fatalf("Failed to save operation: %v", err)
	}

	// Verify that the row landed in the database.
	var count int
	err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM operation WHERE op_id = $1", op.OpID).Scan(&count)
	if err != nil {
		t.Fatalf("Failed to query database: %v", err)
	}
	if count != 1 {
		t.Errorf("Expected 1 record, got %d", count)
	}

	t.Logf("✅ PostgreSQL PersistenceManager test passed")
}