package persistence import ( "context" "database/sql" "testing" "time" _ "github.com/mattn/go-sqlite3" "go.yandata.net/iod/iod/go-trustlog/api/logger" ) func TestPersistenceManager_DBOnly(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyDBOnly, } manager := NewPersistenceManager(db, config, log) if manager == nil { t.Fatal("failed to create PersistenceManager") } op := createTestOperation(t, "manager-test-001") err := manager.SaveOperation(ctx, op) if err != nil { t.Fatalf("failed to save operation: %v", err) } // 验证已保存到数据库 var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM operation WHERE op_id = ?", "manager-test-001").Scan(&count) if err != nil { t.Fatalf("failed to query database: %v", err) } if count != 1 { t.Errorf("expected 1 record, got %d", count) } } func TestPersistenceManager_TrustlogOnly(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyTrustlogOnly, } manager := NewPersistenceManager(db, config, log) if manager == nil { t.Fatal("failed to create PersistenceManager") } op := createTestOperation(t, "manager-test-002") err := manager.SaveOperation(ctx, op) if err != nil { t.Fatalf("failed to save operation: %v", err) } // TrustlogOnly 不会保存到数据库,应该查不到 var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM operation WHERE op_id = ?", "manager-test-002").Scan(&count) if err != nil { t.Fatalf("failed to query database: %v", err) } if count != 0 { t.Errorf("expected 0 records (TrustlogOnly should not save to DB), got %d", count) } } func TestPersistenceManager_DBAndTrustlog(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyDBAndTrustlog, } manager := NewPersistenceManager(db, config, log) if manager == nil { t.Fatal("failed to create PersistenceManager") } op := createTestOperation(t, "manager-test-003") err := manager.SaveOperation(ctx, op) if err != nil { t.Fatalf("failed to save operation: %v", err) } // DBAndTrustlog 会保存到数据库,状态为 NOT_TRUSTLOGGED var status string err = db.QueryRowContext(ctx, "SELECT trustlog_status FROM operation WHERE op_id = ?", "manager-test-003").Scan(&status) if err != nil { t.Fatalf("failed to query database: %v", err) } if status != "NOT_TRUSTLOGGED" { t.Errorf("expected status to be NOT_TRUSTLOGGED, got %s", status) } } func TestPersistenceManager_GetRepositories(t *testing.T) { db := setupTestDB(t) defer db.Close() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyDBOnly, } manager := NewPersistenceManager(db, config, log) // 测试获取各个 Repository opRepo := manager.GetOperationRepo() if opRepo == nil { t.Error("GetOperationRepo returned nil") } cursorRepo := manager.GetCursorRepo() if cursorRepo == nil { t.Error("GetCursorRepo returned nil") } retryRepo := manager.GetRetryRepo() if retryRepo == nil { t.Error("GetRetryRepo returned nil") } } func TestPersistenceManager_Close(t *testing.T) { db := setupTestDB(t) defer db.Close() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyDBOnly, } manager := NewPersistenceManager(db, config, log) err := manager.Close() if err != nil { t.Errorf("Close returned error: %v", err) } } func TestPersistenceManager_InitSchema(t *testing.T) { // 创建一个空数据库 db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("failed to open database: %v", err) } defer db.Close() log := logger.GetGlobalLogger() config := PersistenceConfig{ Strategy: StrategyDBOnly, } manager := NewPersistenceManager(db, config, log) // 手动调用 InitSchema(如果 NewPersistenceManager 没有自动调用) err = manager.InitSchema(context.Background(), "sqlite3") if err != nil { t.Fatalf("InitSchema failed: %v", err) } // 验证表已创建 var count int err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='operation'").Scan(&count) if err != nil { t.Fatalf("failed to query schema: %v", err) } if count != 1 { t.Errorf("expected operation table to exist, got count=%d", count) } } func TestOperationRepository_SaveTx(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() repo := NewOperationRepository(db, log) // 开始事务 tx, err := db.BeginTx(ctx, nil) if err != nil { t.Fatalf("failed to begin transaction: %v", err) } op := createTestOperation(t, "tx-test-001") // 在事务中保存 err = repo.SaveTx(ctx, tx, op, StatusNotTrustlogged) if err != nil { tx.Rollback() t.Fatalf("failed to save operation in transaction: %v", err) } // 提交事务 err = tx.Commit() if err != nil { t.Fatalf("failed to commit transaction: %v", err) } // 验证保存成功 savedOp, _, err := repo.FindByID(ctx, "tx-test-001") if err != nil { t.Fatalf("failed to find operation: %v", err) } if savedOp.OpID != "tx-test-001" { t.Errorf("expected OpID to be 'tx-test-001', got %s", savedOp.OpID) } } func TestOperationRepository_SaveTxRollback(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() repo := NewOperationRepository(db, log) // 开始事务 tx, err := db.BeginTx(ctx, nil) if err != nil { t.Fatalf("failed to begin transaction: %v", err) } op := createTestOperation(t, "tx-test-002") // 在事务中保存 err = repo.SaveTx(ctx, tx, op, StatusNotTrustlogged) if err != nil { tx.Rollback() t.Fatalf("failed to save operation in transaction: %v", err) } // 回滚事务 err = tx.Rollback() if err != nil { t.Fatalf("failed to rollback transaction: %v", err) } // 验证未保存 _, _, err = repo.FindByID(ctx, "tx-test-002") if err == nil { t.Error("expected error when finding rolled back operation, got nil") } } func TestRetryRepository_AddRetryTx(t *testing.T) { db := setupTestDB(t) defer db.Close() ctx := context.Background() log := logger.GetGlobalLogger() repo := NewRetryRepository(db, log) // 开始事务 tx, err := db.BeginTx(ctx, nil) if err != nil { t.Fatalf("failed to begin transaction: %v", err) } nextRetry := time.Now().Add(-1 * time.Second) err = repo.AddRetryTx(ctx, tx, "tx-retry-001", "test error", nextRetry) if err != nil { tx.Rollback() t.Fatalf("failed to add retry in transaction: %v", err) } // 提交事务 err = tx.Commit() if err != nil { t.Fatalf("failed to commit transaction: %v", err) } // 验证已保存 records, err := repo.FindPendingRetries(ctx, 10) if err != nil { t.Fatalf("failed to find pending retries: %v", err) } found := false for _, record := range records { if record.OpID == "tx-retry-001" { found = true break } } if !found { t.Error("expected to find retry record 'tx-retry-001'") } } func TestGetDialectDDL_AllDrivers(t *testing.T) { drivers := []string{"sqlite3", "postgres", "mysql"} for _, driver := range drivers { t.Run(driver, func(t *testing.T) { opDDL, cursorDDL, retryDDL, err := GetDialectDDL(driver) if err != nil { t.Fatalf("GetDialectDDL(%s) failed: %v", driver, err) } if opDDL == "" { t.Errorf("opDDL is empty for driver %s", driver) } if cursorDDL == "" { t.Errorf("cursorDDL is empty for driver %s", driver) } if retryDDL == "" { t.Errorf("retryDDL is empty for driver %s", driver) } }) } } func TestGetDialectDDL_UnknownDriver(t *testing.T) { // GetDialectDDL 对未知驱动返回通用 SQL(而不是错误) opDDL, cursorDDL, retryDDL, err := GetDialectDDL("unknown-driver") if err != nil { t.Fatalf("GetDialectDDL should not error for unknown driver, got: %v", err) } // 应该返回非空的 DDL if opDDL == "" { t.Error("expected non-empty operation DDL") } if cursorDDL == "" { t.Error("expected non-empty cursor DDL") } if retryDDL == "" { t.Error("expected non-empty retry DDL") } }