主要变更:
- Operation.OpType: Type → string
- NewFullOperation 参数: opType Type → opType string
- IsValidOpType 参数: opType Type → opType string
- operationMeta.OpType: *Type → *string
- queryclient.ListRequest.OpType: model.Type → string

优点:
- 更灵活，支持动态扩展操作类型
- 不再受限于预定义的枚举常量
- 简化类型转换逻辑

兼容性:
- Type 常量定义保持不变 (OpTypeCreate, OpTypeUpdate 等)
- 使用时需要 string() 转换: string(model.OpTypeCreate)
- 所有单元测试已更新并通过 (100%)

测试结果:
✅ api/adapter - PASS
✅ api/highclient - PASS
✅ api/logger - PASS
✅ api/model - PASS
✅ api/persistence - PASS
✅ api/queryclient - PASS
✅ internal/* - PASS
392 lines
9.4 KiB
Go
392 lines
9.4 KiB
Go
package persistence
|
||
|
||
import (
	"context"
	"database/sql"
	"fmt"
	"testing"
	"time"

	_ "github.com/mattn/go-sqlite3"

	"go.yandata.net/iod/iod/go-trustlog/api/logger"
	"go.yandata.net/iod/iod/go-trustlog/api/model"
)
|
||
|
||
// setupTestDB 创建测试用的 SQLite 内存数据库
|
||
func setupTestDB(t *testing.T) *sql.DB {
|
||
db, err := sql.Open("sqlite3", ":memory:")
|
||
if err != nil {
|
||
t.Fatalf("failed to open test database: %v", err)
|
||
}
|
||
|
||
// 创建表
|
||
opDDL, cursorDDL, retryDDL, err := GetDialectDDL("sqlite3")
|
||
if err != nil {
|
||
t.Fatalf("failed to get DDL: %v", err)
|
||
}
|
||
|
||
if _, err := db.Exec(opDDL); err != nil {
|
||
t.Fatalf("failed to create operation table: %v", err)
|
||
}
|
||
|
||
if _, err := db.Exec(cursorDDL); err != nil {
|
||
t.Fatalf("failed to create cursor table: %v", err)
|
||
}
|
||
|
||
if _, err := db.Exec(retryDDL); err != nil {
|
||
t.Fatalf("failed to create retry table: %v", err)
|
||
}
|
||
|
||
return db
|
||
}
|
||
|
||
// createTestOperation 创建测试用的 Operation
|
||
func createTestOperation(t *testing.T, opID string) *model.Operation {
|
||
op, err := model.NewFullOperation(
|
||
model.OpSourceDOIP,
|
||
string(model.OpTypeCreate),
|
||
"10.1000",
|
||
"test-repo",
|
||
"10.1000/test-repo/"+opID,
|
||
"producer-001",
|
||
"test-actor",
|
||
[]byte(`{"test":"request"}`),
|
||
[]byte(`{"test":"response"}`),
|
||
time.Now(),
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("failed to create test operation: %v", err)
|
||
}
|
||
|
||
op.OpID = opID // 覆盖自动生成的 ID
|
||
return op
|
||
}
|
||
|
||
func TestOperationRepository_Save(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewOperationRepository(db, log)
|
||
|
||
op := createTestOperation(t, "test-op-001")
|
||
|
||
// 设置 IP 字段
|
||
clientIP := "192.168.1.100"
|
||
serverIP := "10.0.0.50"
|
||
op.ClientIP = &clientIP
|
||
op.ServerIP = &serverIP
|
||
|
||
// 测试保存
|
||
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
||
if err != nil {
|
||
t.Fatalf("failed to save operation: %v", err)
|
||
}
|
||
|
||
// 验证保存结果
|
||
savedOp, status, err := repo.FindByID(ctx, "test-op-001")
|
||
if err != nil {
|
||
t.Fatalf("failed to find operation: %v", err)
|
||
}
|
||
|
||
if savedOp.OpID != "test-op-001" {
|
||
t.Errorf("expected OpID to be 'test-op-001', got %s", savedOp.OpID)
|
||
}
|
||
|
||
if status != StatusNotTrustlogged {
|
||
t.Errorf("expected status to be StatusNotTrustlogged, got %v", status)
|
||
}
|
||
|
||
if savedOp.ClientIP == nil || *savedOp.ClientIP != "192.168.1.100" {
|
||
t.Error("ClientIP not saved correctly")
|
||
}
|
||
|
||
if savedOp.ServerIP == nil || *savedOp.ServerIP != "10.0.0.50" {
|
||
t.Error("ServerIP not saved correctly")
|
||
}
|
||
}
|
||
|
||
func TestOperationRepository_SaveWithNullIP(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewOperationRepository(db, log)
|
||
|
||
op := createTestOperation(t, "test-op-002")
|
||
// IP 字段保持为 nil
|
||
|
||
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
||
if err != nil {
|
||
t.Fatalf("failed to save operation: %v", err)
|
||
}
|
||
|
||
savedOp, _, err := repo.FindByID(ctx, "test-op-002")
|
||
if err != nil {
|
||
t.Fatalf("failed to find operation: %v", err)
|
||
}
|
||
|
||
if savedOp.ClientIP != nil {
|
||
t.Error("ClientIP should be nil")
|
||
}
|
||
|
||
if savedOp.ServerIP != nil {
|
||
t.Error("ServerIP should be nil")
|
||
}
|
||
}
|
||
|
||
func TestOperationRepository_UpdateStatus(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewOperationRepository(db, log)
|
||
|
||
op := createTestOperation(t, "test-op-003")
|
||
|
||
// 先保存
|
||
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
||
if err != nil {
|
||
t.Fatalf("failed to save operation: %v", err)
|
||
}
|
||
|
||
// 更新状态
|
||
err = repo.UpdateStatus(ctx, "test-op-003", StatusTrustlogged)
|
||
if err != nil {
|
||
t.Fatalf("failed to update status: %v", err)
|
||
}
|
||
|
||
// 验证更新结果
|
||
_, status, err := repo.FindByID(ctx, "test-op-003")
|
||
if err != nil {
|
||
t.Fatalf("failed to find operation: %v", err)
|
||
}
|
||
|
||
if status != StatusTrustlogged {
|
||
t.Errorf("expected status to be StatusTrustlogged, got %v", status)
|
||
}
|
||
}
|
||
|
||
func TestOperationRepository_FindUntrustlogged(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewOperationRepository(db, log)
|
||
|
||
// 保存多个操作
|
||
for i := 1; i <= 5; i++ {
|
||
op := createTestOperation(t, "test-op-00"+string(rune('0'+i)))
|
||
status := StatusNotTrustlogged
|
||
if i%2 == 0 {
|
||
status = StatusTrustlogged
|
||
}
|
||
err := repo.Save(ctx, op, status)
|
||
if err != nil {
|
||
t.Fatalf("failed to save operation %d: %v", i, err)
|
||
}
|
||
}
|
||
|
||
// 查询未存证的操作
|
||
ops, err := repo.FindUntrustlogged(ctx, 10)
|
||
if err != nil {
|
||
t.Fatalf("failed to find untrustlogged operations: %v", err)
|
||
}
|
||
|
||
// 应该有 3 个未存证的操作(1, 3, 5)
|
||
if len(ops) != 3 {
|
||
t.Errorf("expected 3 untrustlogged operations, got %d", len(ops))
|
||
}
|
||
}
|
||
|
||
func TestCursorRepository_GetAndUpdate(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewCursorRepository(db, log)
|
||
|
||
cursorKey := "test-cursor"
|
||
|
||
// 初始化游标
|
||
now := time.Now().Format(time.RFC3339Nano)
|
||
err := repo.InitCursor(ctx, cursorKey, now)
|
||
if err != nil {
|
||
t.Fatalf("failed to init cursor: %v", err)
|
||
}
|
||
|
||
// 获取游标值
|
||
cursorValue, err := repo.GetCursor(ctx, cursorKey)
|
||
if err != nil {
|
||
t.Fatalf("failed to get cursor: %v", err)
|
||
}
|
||
|
||
if cursorValue != now {
|
||
t.Errorf("expected cursor value to be %s, got %s", now, cursorValue)
|
||
}
|
||
|
||
// 更新游标
|
||
newTime := time.Now().Add(1 * time.Hour).Format(time.RFC3339Nano)
|
||
err = repo.UpdateCursor(ctx, cursorKey, newTime)
|
||
if err != nil {
|
||
t.Fatalf("failed to update cursor: %v", err)
|
||
}
|
||
|
||
// 验证更新结果
|
||
cursorValue, err = repo.GetCursor(ctx, cursorKey)
|
||
if err != nil {
|
||
t.Fatalf("failed to get cursor: %v", err)
|
||
}
|
||
|
||
if cursorValue != newTime {
|
||
t.Errorf("expected cursor value to be %s, got %s", newTime, cursorValue)
|
||
}
|
||
}
|
||
|
||
func TestRetryRepository_AddAndFind(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewRetryRepository(db, log)
|
||
|
||
// 添加重试记录(立即可以重试)
|
||
nextRetry := time.Now().Add(-1 * time.Second) // 过去的时间,立即可以查询到
|
||
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
||
if err != nil {
|
||
t.Fatalf("failed to add retry: %v", err)
|
||
}
|
||
|
||
// 查找待重试的记录
|
||
records, err := repo.FindPendingRetries(ctx, 10)
|
||
if err != nil {
|
||
t.Fatalf("failed to find pending retries: %v", err)
|
||
}
|
||
|
||
if len(records) != 1 {
|
||
t.Errorf("expected 1 retry record, got %d", len(records))
|
||
}
|
||
|
||
if len(records) > 0 {
|
||
if records[0].OpID != "test-op-001" {
|
||
t.Errorf("expected OpID to be 'test-op-001', got %s", records[0].OpID)
|
||
}
|
||
|
||
if records[0].RetryStatus != RetryStatusPending {
|
||
t.Errorf("expected status to be PENDING, got %v", records[0].RetryStatus)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestRetryRepository_IncrementRetry(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewRetryRepository(db, log)
|
||
|
||
// 添加重试记录
|
||
nextRetry := time.Now().Add(-1 * time.Second)
|
||
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
||
if err != nil {
|
||
t.Fatalf("failed to add retry: %v", err)
|
||
}
|
||
|
||
// 增加重试次数(立即可以重试)
|
||
nextRetry2 := time.Now().Add(-1 * time.Second)
|
||
err = repo.IncrementRetry(ctx, "test-op-001", "test error 2", nextRetry2)
|
||
if err != nil {
|
||
t.Fatalf("failed to increment retry: %v", err)
|
||
}
|
||
|
||
// 验证重试次数
|
||
records, err := repo.FindPendingRetries(ctx, 10)
|
||
if err != nil {
|
||
t.Fatalf("failed to find pending retries: %v", err)
|
||
}
|
||
|
||
if len(records) != 1 {
|
||
t.Fatalf("expected 1 retry record, got %d", len(records))
|
||
}
|
||
|
||
if records[0].RetryCount != 1 {
|
||
t.Errorf("expected RetryCount to be 1, got %d", records[0].RetryCount)
|
||
}
|
||
|
||
if records[0].RetryStatus != RetryStatusRetrying {
|
||
t.Errorf("expected status to be RETRYING, got %v", records[0].RetryStatus)
|
||
}
|
||
}
|
||
|
||
func TestRetryRepository_MarkAsDeadLetter(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewRetryRepository(db, log)
|
||
|
||
// 添加重试记录
|
||
nextRetry := time.Now().Add(-1 * time.Second)
|
||
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
||
if err != nil {
|
||
t.Fatalf("failed to add retry: %v", err)
|
||
}
|
||
|
||
// 标记为死信
|
||
err = repo.MarkAsDeadLetter(ctx, "test-op-001", "max retries exceeded")
|
||
if err != nil {
|
||
t.Fatalf("failed to mark as dead letter: %v", err)
|
||
}
|
||
|
||
// 验证状态(死信不应该在待重试列表中)
|
||
records, err := repo.FindPendingRetries(ctx, 10)
|
||
if err != nil {
|
||
t.Fatalf("failed to find pending retries: %v", err)
|
||
}
|
||
|
||
if len(records) != 0 {
|
||
t.Errorf("expected 0 pending retry records, got %d", len(records))
|
||
}
|
||
}
|
||
|
||
func TestRetryRepository_DeleteRetry(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer db.Close()
|
||
|
||
ctx := context.Background()
|
||
log := logger.GetGlobalLogger()
|
||
repo := NewRetryRepository(db, log)
|
||
|
||
// 添加重试记录
|
||
nextRetry := time.Now().Add(-1 * time.Second)
|
||
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
||
if err != nil {
|
||
t.Fatalf("failed to add retry: %v", err)
|
||
}
|
||
|
||
// 删除重试记录
|
||
err = repo.DeleteRetry(ctx, "test-op-001")
|
||
if err != nil {
|
||
t.Fatalf("failed to delete retry: %v", err)
|
||
}
|
||
|
||
// 验证已删除
|
||
records, err := repo.FindPendingRetries(ctx, 10)
|
||
if err != nil {
|
||
t.Fatalf("failed to find pending retries: %v", err)
|
||
}
|
||
|
||
if len(records) != 0 {
|
||
t.Errorf("expected 0 retry records, got %d", len(records))
|
||
}
|
||
}
|
||
|