package persistence

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"testing"
	"time"

	_ "github.com/mattn/go-sqlite3"

	"go.yandata.net/iod/iod/go-trustlog/api/logger"
	"go.yandata.net/iod/iod/go-trustlog/api/model"
)

// setupTestDB opens an in-memory SQLite database private to the calling test
// and creates the operation, cursor and retry tables from the "sqlite3"
// dialect DDL. The database is closed automatically when the test finishes.
//
// A plain ":memory:" DSN is deliberately avoided: database/sql pools
// connections, and every new SQLite connection to ":memory:" opens its own
// brand-new, empty database, so a second pooled connection would not see the
// tables created here. A named shared-cache in-memory database is visible to
// every connection in the pool; naming it after the test keeps tests isolated
// from one another.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()

	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", url.PathEscape(t.Name()))
	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		t.Fatalf("failed to open test database: %v", err)
	}
	// Register the close right away so the handle is released even if a
	// later t.Fatalf aborts setup. A close error at teardown is not actionable.
	t.Cleanup(func() { _ = db.Close() })

	// Create the tables.
	opDDL, cursorDDL, retryDDL, err := GetDialectDDL("sqlite3")
	if err != nil {
		t.Fatalf("failed to get DDL: %v", err)
	}
	if _, err := db.Exec(opDDL); err != nil {
		t.Fatalf("failed to create operation table: %v", err)
	}
	if _, err := db.Exec(cursorDDL); err != nil {
		t.Fatalf("failed to create cursor table: %v", err)
	}
	if _, err := db.Exec(retryDDL); err != nil {
		t.Fatalf("failed to create retry table: %v", err)
	}

	return db
}

// createTestOperation builds a DOIP "create" Operation with fixed test
// payloads and forces its OpID to opID (overriding the generated ID) so tests
// can look it up deterministically.
func createTestOperation(t *testing.T, opID string) *model.Operation {
	t.Helper()

	op, err := model.NewFullOperation(
		model.OpSourceDOIP,
		model.OpTypeCreate,
		"10.1000",
		"test-repo",
		"10.1000/test-repo/"+opID,
		"producer-001",
		"test-actor",
		[]byte(`{"test":"request"}`),
		[]byte(`{"test":"response"}`),
		time.Now(),
	)
	if err != nil {
		t.Fatalf("failed to create test operation: %v", err)
	}
	op.OpID = opID // override the auto-generated ID
	return op
}

func TestOperationRepository_Save(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewOperationRepository(db, log)

	op := createTestOperation(t, "test-op-001")

	// Populate the optional IP fields so the round trip covers them.
	clientIP := "192.168.1.100"
	serverIP := "10.0.0.50"
	op.ClientIP = &clientIP
	op.ServerIP = &serverIP

	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}

	// Read the row back and verify every saved field.
	savedOp, status, err := repo.FindByID(ctx, "test-op-001")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}
	if savedOp.OpID != "test-op-001" {
		t.Errorf("expected OpID to be 'test-op-001', got %s", savedOp.OpID)
	}
	if status != StatusNotTrustlogged {
		t.Errorf("expected status to be StatusNotTrustlogged, got %v", status)
	}
	if savedOp.ClientIP == nil || *savedOp.ClientIP != clientIP {
		t.Error("ClientIP not saved correctly")
	}
	if savedOp.ServerIP == nil || *savedOp.ServerIP != serverIP {
		t.Error("ServerIP not saved correctly")
	}
}

func TestOperationRepository_SaveWithNullIP(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewOperationRepository(db, log)

	// Leave the IP fields nil; they must round-trip as nil.
	op := createTestOperation(t, "test-op-002")

	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}

	savedOp, _, err := repo.FindByID(ctx, "test-op-002")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}
	if savedOp.ClientIP != nil {
		t.Error("ClientIP should be nil")
	}
	if savedOp.ServerIP != nil {
		t.Error("ServerIP should be nil")
	}
}

func TestOperationRepository_UpdateStatus(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewOperationRepository(db, log)

	op := createTestOperation(t, "test-op-003")

	// Save first, then flip the status.
	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}
	if err := repo.UpdateStatus(ctx, "test-op-003", StatusTrustlogged); err != nil {
		t.Fatalf("failed to update status: %v", err)
	}

	// Verify the update took effect.
	_, status, err := repo.FindByID(ctx, "test-op-003")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}
	if status != StatusTrustlogged {
		t.Errorf("expected status to be StatusTrustlogged, got %v", status)
	}
}

func TestOperationRepository_FindUntrustlogged(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewOperationRepository(db, log)

	// Save test-op-001..test-op-005; even-numbered ones are already trustlogged.
	for i := 1; i <= 5; i++ {
		op := createTestOperation(t, fmt.Sprintf("test-op-%03d", i))
		status := StatusNotTrustlogged
		if i%2 == 0 {
			status = StatusTrustlogged
		}
		if err := repo.Save(ctx, op, status); err != nil {
			t.Fatalf("failed to save operation %d: %v", i, err)
		}
	}

	// Query the operations that are not yet trustlogged.
	ops, err := repo.FindUntrustlogged(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find untrustlogged operations: %v", err)
	}

	// Exactly the odd-numbered operations (1, 3, 5) should be returned.
	if len(ops) != 3 {
		t.Errorf("expected 3 untrustlogged operations, got %d", len(ops))
	}
}

func TestCursorRepository_GetAndUpdate(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewCursorRepository(db, log)

	cursorKey := "test-cursor"

	// Initialise the cursor.
	now := time.Now().Format(time.RFC3339Nano)
	if err := repo.InitCursor(ctx, cursorKey, now); err != nil {
		t.Fatalf("failed to init cursor: %v", err)
	}

	// Read it back.
	cursorValue, err := repo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("failed to get cursor: %v", err)
	}
	if cursorValue != now {
		t.Errorf("expected cursor value to be %s, got %s", now, cursorValue)
	}

	// Advance the cursor and verify the new value is persisted.
	newTime := time.Now().Add(1 * time.Hour).Format(time.RFC3339Nano)
	if err := repo.UpdateCursor(ctx, cursorKey, newTime); err != nil {
		t.Fatalf("failed to update cursor: %v", err)
	}

	cursorValue, err = repo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("failed to get cursor: %v", err)
	}
	if cursorValue != newTime {
		t.Errorf("expected cursor value to be %s, got %s", newTime, cursorValue)
	}
}

func TestRetryRepository_AddAndFind(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewRetryRepository(db, log)

	// A next-retry time in the past makes the record immediately eligible.
	nextRetry := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Look up the pending retries.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if len(records) != 1 {
		t.Fatalf("expected 1 retry record, got %d", len(records))
	}
	if records[0].OpID != "test-op-001" {
		t.Errorf("expected OpID to be 'test-op-001', got %s", records[0].OpID)
	}
	if records[0].RetryStatus != RetryStatusPending {
		t.Errorf("expected status to be PENDING, got %v", records[0].RetryStatus)
	}
}

func TestRetryRepository_IncrementRetry(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewRetryRepository(db, log)

	// Add a retry record.
	nextRetry := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Bump the retry count, again with a past time so it stays eligible.
	nextRetry2 := time.Now().Add(-1 * time.Second)
	if err := repo.IncrementRetry(ctx, "test-op-001", "test error 2", nextRetry2); err != nil {
		t.Fatalf("failed to increment retry: %v", err)
	}

	// Verify the retry count and status.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if len(records) != 1 {
		t.Fatalf("expected 1 retry record, got %d", len(records))
	}
	if records[0].RetryCount != 1 {
		t.Errorf("expected RetryCount to be 1, got %d", records[0].RetryCount)
	}
	if records[0].RetryStatus != RetryStatusRetrying {
		t.Errorf("expected status to be RETRYING, got %v", records[0].RetryStatus)
	}
}

func TestRetryRepository_MarkAsDeadLetter(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewRetryRepository(db, log)

	// Add a retry record.
	nextRetry := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Mark it as a dead letter.
	if err := repo.MarkAsDeadLetter(ctx, "test-op-001", "max retries exceeded"); err != nil {
		t.Fatalf("failed to mark as dead letter: %v", err)
	}

	// Dead letters must not appear in the pending-retry list.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if len(records) != 0 {
		t.Errorf("expected 0 pending retry records, got %d", len(records))
	}
}

func TestRetryRepository_DeleteRetry(t *testing.T) {
	db := setupTestDB(t)
	ctx := context.Background()
	log := logger.GetGlobalLogger()
	repo := NewRetryRepository(db, log)

	// Add a retry record.
	nextRetry := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Delete it.
	if err := repo.DeleteRetry(ctx, "test-op-001"); err != nil {
		t.Fatalf("failed to delete retry: %v", err)
	}

	// Verify it is gone.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if len(records) != 0 {
		t.Errorf("expected 0 retry records, got %d", len(records))
	}
}