Files
go-trustlog/api/persistence/repository_test.go
ryan a90d853a6e refactor: 将 OpType 字段从枚举类型改为 string 类型
主要变更:
- Operation.OpType: Type → string
- NewFullOperation 参数: opType Type → opType string
- IsValidOpType 参数: opType Type → opType string
- operationMeta.OpType: *Type → *string
- queryclient.ListRequest.OpType: model.Type → string

优点:
- 更灵活,支持动态扩展操作类型
- 不再受限于预定义的枚举常量
- 简化类型转换逻辑

兼容性:
- Type 常量定义保持不变 (OpTypeCreate, OpTypeUpdate 等)
- 使用时需要 string() 转换: string(model.OpTypeCreate)
- 所有单元测试已更新并通过 (100%)

测试结果:
 api/adapter - PASS
 api/highclient - PASS
 api/logger - PASS
 api/model - PASS
 api/persistence - PASS
 api/queryclient - PASS
 internal/* - PASS
2025-12-24 16:48:00 +08:00

392 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package persistence
import (
"context"
"database/sql"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"go.yandata.net/iod/iod/go-trustlog/api/logger"
"go.yandata.net/iod/iod/go-trustlog/api/model"
)
// setupTestDB opens an in-memory SQLite database and creates the operation,
// cursor and retry tables from the DDL returned by GetDialectDDL("sqlite3").
//
// Any failure aborts the calling test through t.Fatalf. On success the caller
// owns the returned handle and is responsible for closing it.
func setupTestDB(t *testing.T) *sql.DB {
	// Attribute failures to the calling test's line, not to this helper.
	t.Helper()

	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("failed to open test database: %v", err)
	}

	opDDL, cursorDDL, retryDDL, err := GetDialectDDL("sqlite3")
	if err != nil {
		// Best-effort cleanup; the test is about to abort anyway.
		_ = db.Close()
		t.Fatalf("failed to get DDL: %v", err)
	}

	// Create the tables in a fixed order: operation, cursor, retry.
	tables := []struct {
		name string
		ddl  string
	}{
		{name: "operation", ddl: opDDL},
		{name: "cursor", ddl: cursorDDL},
		{name: "retry", ddl: retryDDL},
	}
	for _, tbl := range tables {
		if _, err := db.Exec(tbl.ddl); err != nil {
			// Best-effort cleanup so a failed setup does not leak the handle.
			_ = db.Close()
			t.Fatalf("failed to create %s table: %v", tbl.name, err)
		}
	}
	return db
}
// createTestOperation builds a *model.Operation for tests via
// model.NewFullOperation with fixed sample values, then overwrites its OpID
// with opID so tests can look the record up by a known identifier.
//
// The test is aborted with t.Fatalf if construction fails.
func createTestOperation(t *testing.T, opID string) *model.Operation {
	// Attribute failures to the calling test's line, not to this helper.
	t.Helper()

	op, err := model.NewFullOperation(
		model.OpSourceDOIP,
		string(model.OpTypeCreate),
		"10.1000",
		"test-repo",
		"10.1000/test-repo/"+opID,
		"producer-001",
		"test-actor",
		[]byte(`{"test":"request"}`),
		[]byte(`{"test":"response"}`),
		time.Now(),
	)
	if err != nil {
		t.Fatalf("failed to create test operation: %v", err)
	}

	// Replace whatever ID the constructor assigned with the caller's ID.
	op.OpID = opID
	return op
}
// TestOperationRepository_Save stores an operation that carries client and
// server IPs, loads it back by ID, and checks the ID, the status and both IP
// fields.
func TestOperationRepository_Save(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewOperationRepository(db, logger.GetGlobalLogger())

	// Populate the optional IP fields before saving.
	op := createTestOperation(t, "test-op-001")
	clientIP, serverIP := "192.168.1.100", "10.0.0.50"
	op.ClientIP, op.ServerIP = &clientIP, &serverIP

	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}

	// Read the record back and compare it with what was written.
	got, gotStatus, err := repo.FindByID(ctx, "test-op-001")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}

	if got.OpID != "test-op-001" {
		t.Errorf("expected OpID to be 'test-op-001', got %s", got.OpID)
	}
	if gotStatus != StatusNotTrustlogged {
		t.Errorf("expected status to be StatusNotTrustlogged, got %v", gotStatus)
	}
	if ip := got.ClientIP; ip == nil || *ip != "192.168.1.100" {
		t.Error("ClientIP not saved correctly")
	}
	if ip := got.ServerIP; ip == nil || *ip != "10.0.0.50" {
		t.Error("ServerIP not saved correctly")
	}
}
// TestOperationRepository_SaveWithNullIP stores an operation whose IP fields
// are left unset and checks that both come back as nil.
func TestOperationRepository_SaveWithNullIP(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewOperationRepository(db, logger.GetGlobalLogger())

	// The IP fields are deliberately not assigned here.
	op := createTestOperation(t, "test-op-002")
	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}

	got, _, err := repo.FindByID(ctx, "test-op-002")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}

	if got.ClientIP != nil {
		t.Error("ClientIP should be nil")
	}
	if got.ServerIP != nil {
		t.Error("ServerIP should be nil")
	}
}
// TestOperationRepository_UpdateStatus saves an operation as
// StatusNotTrustlogged, updates it to StatusTrustlogged, and checks that
// FindByID reports the new status.
func TestOperationRepository_UpdateStatus(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewOperationRepository(db, logger.GetGlobalLogger())

	// Step 1: persist the operation with its initial status.
	op := createTestOperation(t, "test-op-003")
	if err := repo.Save(ctx, op, StatusNotTrustlogged); err != nil {
		t.Fatalf("failed to save operation: %v", err)
	}

	// Step 2: change the status.
	if err := repo.UpdateStatus(ctx, "test-op-003", StatusTrustlogged); err != nil {
		t.Fatalf("failed to update status: %v", err)
	}

	// Step 3: confirm the change is visible on read.
	_, gotStatus, err := repo.FindByID(ctx, "test-op-003")
	if err != nil {
		t.Fatalf("failed to find operation: %v", err)
	}
	if gotStatus != StatusTrustlogged {
		t.Errorf("expected status to be StatusTrustlogged, got %v", gotStatus)
	}
}
// TestOperationRepository_FindUntrustlogged saves five operations, marking
// the even-numbered ones as StatusTrustlogged, and checks that
// FindUntrustlogged returns the three remaining ones.
func TestOperationRepository_FindUntrustlogged(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewOperationRepository(db, logger.GetGlobalLogger())

	// Seed operations test-op-001 .. test-op-005.
	for i := 1; i <= 5; i++ {
		// Single-digit suffix: '0'+i maps 1..5 to the runes '1'..'5'.
		opID := "test-op-00" + string(rune('0'+i))

		saveStatus := StatusNotTrustlogged
		if i%2 == 0 {
			saveStatus = StatusTrustlogged
		}

		if err := repo.Save(ctx, createTestOperation(t, opID), saveStatus); err != nil {
			t.Fatalf("failed to save operation %d: %v", i, err)
		}
	}

	// Query with a limit larger than the number of seeded rows.
	pending, err := repo.FindUntrustlogged(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find untrustlogged operations: %v", err)
	}

	// Operations 1, 3 and 5 were saved as not trustlogged.
	if n := len(pending); n != 3 {
		t.Errorf("expected 3 untrustlogged operations, got %d", n)
	}
}
// TestCursorRepository_GetAndUpdate initialises a cursor with a timestamp
// string, reads it back, updates it to a later timestamp, and checks that the
// second read returns the updated value.
func TestCursorRepository_GetAndUpdate(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewCursorRepository(db, logger.GetGlobalLogger())

	const cursorKey = "test-cursor"

	// Initialise the cursor with the current time.
	initial := time.Now().Format(time.RFC3339Nano)
	if err := repo.InitCursor(ctx, cursorKey, initial); err != nil {
		t.Fatalf("failed to init cursor: %v", err)
	}

	// The stored value must match what was just written.
	got, err := repo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("failed to get cursor: %v", err)
	}
	if got != initial {
		t.Errorf("expected cursor value to be %s, got %s", initial, got)
	}

	// Move the cursor one hour ahead.
	updated := time.Now().Add(1 * time.Hour).Format(time.RFC3339Nano)
	if err := repo.UpdateCursor(ctx, cursorKey, updated); err != nil {
		t.Fatalf("failed to update cursor: %v", err)
	}

	// The second read must return the updated value.
	got, err = repo.GetCursor(ctx, cursorKey)
	if err != nil {
		t.Fatalf("failed to get cursor: %v", err)
	}
	if got != updated {
		t.Errorf("expected cursor value to be %s, got %s", updated, got)
	}
}
// TestRetryRepository_AddAndFind adds one retry record whose next-retry time
// is already in the past and checks that FindPendingRetries returns it with
// the expected OpID and RetryStatusPending.
func TestRetryRepository_AddAndFind(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewRetryRepository(db, logger.GetGlobalLogger())

	// A next-retry time one second in the past makes the record due now.
	due := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", due); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}

	if len(records) != 1 {
		t.Errorf("expected 1 retry record, got %d", len(records))
	}
	// Nothing to inspect if the query came back empty.
	if len(records) == 0 {
		return
	}

	first := records[0]
	if first.OpID != "test-op-001" {
		t.Errorf("expected OpID to be 'test-op-001', got %s", first.OpID)
	}
	if first.RetryStatus != RetryStatusPending {
		t.Errorf("expected status to be PENDING, got %v", first.RetryStatus)
	}
}
// TestRetryRepository_IncrementRetry adds a retry record, calls
// IncrementRetry once, and checks that the pending record then has
// RetryCount 1 and RetryStatusRetrying.
func TestRetryRepository_IncrementRetry(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewRetryRepository(db, logger.GetGlobalLogger())

	// Create the initial retry record, due immediately.
	firstDue := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", firstDue); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Increment the retry count, again with a next-retry time in the past.
	secondDue := time.Now().Add(-1 * time.Second)
	if err := repo.IncrementRetry(ctx, "test-op-001", "test error 2", secondDue); err != nil {
		t.Fatalf("failed to increment retry: %v", err)
	}

	// Inspect the single pending record.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if len(records) != 1 {
		t.Fatalf("expected 1 retry record, got %d", len(records))
	}

	rec := records[0]
	if rec.RetryCount != 1 {
		t.Errorf("expected RetryCount to be 1, got %d", rec.RetryCount)
	}
	if rec.RetryStatus != RetryStatusRetrying {
		t.Errorf("expected status to be RETRYING, got %v", rec.RetryStatus)
	}
}
// TestRetryRepository_MarkAsDeadLetter adds a retry record, marks it as a
// dead letter, and checks that FindPendingRetries no longer returns it.
func TestRetryRepository_MarkAsDeadLetter(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewRetryRepository(db, logger.GetGlobalLogger())

	// Create a retry record that is due immediately.
	due := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", due); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Move the record to the dead-letter state.
	if err := repo.MarkAsDeadLetter(ctx, "test-op-001", "max retries exceeded"); err != nil {
		t.Fatalf("failed to mark as dead letter: %v", err)
	}

	// A dead-lettered record must not appear among the pending retries.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if n := len(records); n != 0 {
		t.Errorf("expected 0 pending retry records, got %d", n)
	}
}
// TestRetryRepository_DeleteRetry adds a retry record, deletes it, and checks
// that FindPendingRetries returns no records afterwards.
func TestRetryRepository_DeleteRetry(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	ctx := context.Background()
	repo := NewRetryRepository(db, logger.GetGlobalLogger())

	// Create a retry record that is due immediately.
	due := time.Now().Add(-1 * time.Second)
	if err := repo.AddRetry(ctx, "test-op-001", "test error", due); err != nil {
		t.Fatalf("failed to add retry: %v", err)
	}

	// Remove the record.
	if err := repo.DeleteRetry(ctx, "test-op-001"); err != nil {
		t.Fatalf("failed to delete retry: %v", err)
	}

	// The pending list must now be empty.
	records, err := repo.FindPendingRetries(ctx, 10)
	if err != nil {
		t.Fatalf("failed to find pending retries: %v", err)
	}
	if n := len(records); n != 0 {
		t.Errorf("expected 0 retry records, got %d", n)
	}
}