diff --git a/api/persistence/client.go b/api/persistence/client.go index 4fd4b3b..13557c1 100644 --- a/api/persistence/client.go +++ b/api/persistence/client.go @@ -396,3 +396,45 @@ func (c *PersistenceClient) Close() error { c.logger.Info("persistence client closed successfully") return nil } + +// QueryOperations 查询操作记录(支持分页、筛选、排序) +func (c *PersistenceClient) QueryOperations(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) { + if c.manager == nil { + return nil, fmt.Errorf("persistence manager not initialized") + } + + repo := c.manager.GetOperationRepo() + if repo == nil { + return nil, fmt.Errorf("operation repository not available") + } + + return repo.Query(ctx, req) +} + +// CountOperations 统计符合条件的操作记录数 +func (c *PersistenceClient) CountOperations(ctx context.Context, req *OperationQueryRequest) (int64, error) { + if c.manager == nil { + return 0, fmt.Errorf("persistence manager not initialized") + } + + repo := c.manager.GetOperationRepo() + if repo == nil { + return 0, fmt.Errorf("operation repository not available") + } + + return repo.Count(ctx, req) +} + +// GetOperationByID 根据 OpID 查询单条操作记录 +func (c *PersistenceClient) GetOperationByID(ctx context.Context, opID string) (*model.Operation, TrustlogStatus, error) { + if c.manager == nil { + return nil, "", fmt.Errorf("persistence manager not initialized") + } + + repo := c.manager.GetOperationRepo() + if repo == nil { + return nil, "", fmt.Errorf("operation repository not available") + } + + return repo.FindByID(ctx, opID) +} diff --git a/api/persistence/cluster_safety_test.go b/api/persistence/cluster_safety_test.go index 40bcefd..8434a99 100644 --- a/api/persistence/cluster_safety_test.go +++ b/api/persistence/cluster_safety_test.go @@ -56,6 +56,12 @@ func TestClusterSafety_MultipleCursorWorkers(t *testing.T) { t.Log("✅ PostgreSQL connected") + // 确保schema是最新的(添加可能缺失的列) + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS op_hash VARCHAR(128)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS sign VARCHAR(512)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS timestamp TIMESTAMP") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + // 创建测试数据:50 条未存证记录 operationCount := 50 timestamp := time.Now().Unix() @@ -63,12 +69,14 @@ func TestClusterSafety_MultipleCursorWorkers(t *testing.T) { opID := fmt.Sprintf("cluster-test-%d-%d", timestamp, i) _, err := db.Exec(` INSERT INTO operation ( - op_id, op_actor, doid, producer_id, + op_id, op_actor, doid, producer_id, + request_body_hash, response_body_hash, op_hash, sign, op_source, op_type, do_prefix, do_repository, - trustlog_status, created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW()) + trustlog_status, timestamp, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, NOW()) `, opID, "cluster-tester", fmt.Sprintf("cluster/test/%d", i), "cluster-producer", - "DOIP", "CREATE", "cluster-test", "cluster-repo", "NOT_TRUSTLOGGED") + "req-hash", "resp-hash", "op-hash", "signature", + "DOIP", "CREATE", "cluster-test", "cluster-repo", "NOT_TRUSTLOGGED", time.Now()) if err != nil { t.Fatalf("Failed to create test data: %v", err) diff --git a/api/persistence/cursor_init_verification_test.go b/api/persistence/cursor_init_verification_test.go index edca9a2..9f7af09 100644 --- a/api/persistence/cursor_init_verification_test.go +++ b/api/persistence/cursor_init_verification_test.go @@ -54,6 +54,12 @@ func TestCursorInitialization(t *testing.T) { t.Log("✅ PostgreSQL connected and cleaned") + // 确保schema是最新的(添加可能缺失的列) + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS op_hash VARCHAR(128)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS sign VARCHAR(512)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS timestamp TIMESTAMP") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + // 场景 1: 没有历史数据时启动 t.Run("NoHistoricalData", func(t *testing.T) { // 清理 diff --git a/api/persistence/pg_query_integration_test.go b/api/persistence/pg_query_integration_test.go new file mode 100644 index 0000000..8f4044c --- /dev/null +++ b/api/persistence/pg_query_integration_test.go @@ -0,0 +1,616 @@ +package persistence_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.yandata.net/iod/iod/go-trustlog/api/logger" + "go.yandata.net/iod/iod/go-trustlog/api/model" + "go.yandata.net/iod/iod/go-trustlog/api/persistence" +) + +// TestPG_Query_Integration 测试 PostgreSQL 查询功能集成 +func TestPG_Query_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping PostgreSQL query integration test in short mode") + } + + ctx := context.Background() + log := logger.NewNopLogger() + + // 连接数据库 + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + e2eTestPGHost, e2eTestPGPort, e2eTestPGUser, e2eTestPGPassword, e2eTestPGDatabase) + + db, err := sql.Open("postgres", dsn) + if err != nil { + t.Skipf("PostgreSQL not available: %v", err) + return + } + defer db.Close() + + if err := db.Ping(); err != nil { + t.Skipf("PostgreSQL not reachable: %v", err) + return + } + + // 确保schema是最新的 + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS op_hash VARCHAR(128)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS sign VARCHAR(512)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS timestamp TIMESTAMP") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + + // 清理测试数据 + _, _ = db.Exec("DELETE FROM operation WHERE op_id LIKE 'pg-query-test-%'") + defer func() { + _, _ = db.Exec("DELETE FROM operation WHERE op_id LIKE 'pg-query-test-%'") + }() + + t.Log("✅ PostgreSQL connected and cleaned") + + // 创建 PersistenceManager + persistenceConfig := persistence.PersistenceConfig{ + Strategy: persistence.StrategyDBOnly, + } + manager := persistence.NewPersistenceManager(db, persistenceConfig, log) + defer manager.Close() + + repo := manager.GetOperationRepo() + + // 准备测试数据(20条记录,不同的来源、类型、状态) + baseTime := time.Now().Add(-2 * time.Hour) + testOps := []struct { + opID string + opSource string + opType string + prefix string + doid string + repo string + actor string + producer string + clientIP *string + serverIP *string + status persistence.TrustlogStatus + time time.Time + }{ + {"pg-query-test-001", "DOIP", "Create", "10.10000", "10.10000/test-repo/test-001", "test-repo", "user-1", "producer-1", strPtr("192.168.1.10"), strPtr("10.0.0.1"), persistence.StatusNotTrustlogged, baseTime}, + {"pg-query-test-002", "DOIP", "Create", "10.10000", "10.10000/test-repo/test-002", "test-repo", "user-1", "producer-1", strPtr("192.168.1.10"), strPtr("10.0.0.1"), persistence.StatusTrustlogged, baseTime.Add(10 * time.Minute)}, + {"pg-query-test-003", "DOIP", "Update", "10.10000", "10.10000/test-repo/test-003", "test-repo", "user-2", "producer-1", strPtr("192.168.1.20"), strPtr("10.0.0.1"), persistence.StatusNotTrustlogged, baseTime.Add(20 * time.Minute)}, + {"pg-query-test-004", "DOIP", "Update", "10.10000", "10.10000/test-repo/test-004", "test-repo", "user-2", "producer-2", strPtr("192.168.1.20"), strPtr("10.0.0.2"), persistence.StatusTrustlogged, baseTime.Add(30 * time.Minute)}, + {"pg-query-test-005", "DOIP", "Delete", "10.10000", "10.10000/test-repo/test-005", "test-repo", "user-3", "producer-2", nil, nil, persistence.StatusNotTrustlogged, baseTime.Add(40 * time.Minute)}, + {"pg-query-test-006", "IRP", "OC_CREATE_HANDLE", "20.1000", "20.1000/test-repo/test-001", "test-repo", "user-1", "producer-3", strPtr("192.168.2.10"), strPtr("10.0.1.1"), persistence.StatusTrustlogged, baseTime.Add(50 * time.Minute)}, + {"pg-query-test-007", "IRP", "OC_DELETE_HANDLE", "20.1000", "20.1000/test-repo/test-002", "test-repo", "user-2", "producer-3", strPtr("192.168.2.20"), strPtr("10.0.1.1"), persistence.StatusNotTrustlogged, baseTime.Add(60 * time.Minute)}, + {"pg-query-test-008", "IRP", "OC_LOOKUP_HANDLE", "20.1000", "20.1000/test-repo/test-003", "test-repo", "user-3", "producer-4", nil, strPtr("10.0.1.2"), persistence.StatusTrustlogged, baseTime.Add(70 * time.Minute)}, + {"pg-query-test-009", "DOIP", "Retrieve", "10.20000", "10.20000/test-repo/test-001", "test-repo", "user-1", "producer-1", strPtr("192.168.1.30"), nil, persistence.StatusNotTrustlogged, baseTime.Add(80 * time.Minute)}, + {"pg-query-test-010", "DOIP", "Retrieve", "10.20000", "10.20000/test-repo/test-002", "test-repo", "user-2", "producer-2", strPtr("192.168.1.40"), strPtr("10.0.0.3"), persistence.StatusTrustlogged, baseTime.Add(90 * time.Minute)}, + } + + // 插入测试数据 + for _, testOp := range testOps { + op, err := model.NewFullOperation( + model.Source(testOp.opSource), + testOp.opType, + testOp.prefix, // doPrefix + testOp.repo, // doRepository + testOp.doid, // doid + testOp.producer, // producerID + testOp.actor, // opActor + nil, // requestBody + nil, // responseBody + testOp.time, // timestamp + ) + require.NoError(t, err, "Failed to create operation %s", testOp.opID) + + op.OpID = testOp.opID + op.ClientIP = testOp.clientIP + op.ServerIP = testOp.serverIP + + err = repo.Save(ctx, op, testOp.status) + require.NoError(t, err, "Failed to save operation %s", testOp.opID) + } + + t.Log("✅ Test data created") + + // 测试1: 查询所有记录 + t.Run("Query all records", func(t *testing.T) { + req := &persistence.OperationQueryRequest{ + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(10)) + assert.GreaterOrEqual(t, len(result.Operations), 10) + t.Logf("✅ Total records: %d", result.Total) + }) + + // 测试2: 按 OpSource 筛选 + t.Run("Filter by OpSource", func(t *testing.T) { + opSource := "DOIP" + req := &persistence.OperationQueryRequest{ + OpSource: &opSource, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(7)) // 7条DOIP记录 + + for _, op := range result.Operations { + assert.Equal(t, "DOIP", string(op.OpSource)) + } + t.Logf("✅ DOIP records: %d", result.Total) + }) + + // 测试3: 按 OpType 筛选 + t.Run("Filter by OpType", func(t *testing.T) { + opType := "Create" + req := &persistence.OperationQueryRequest{ + OpType: &opType, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(2)) // 2条Create记录 + + for _, op := range result.Operations { + assert.Equal(t, "Create", op.OpType) + } + t.Logf("✅ Create records: %d", result.Total) + }) + + // 测试4: 按 TrustlogStatus 筛选 + t.Run("Filter by TrustlogStatus", func(t *testing.T) { + status := persistence.StatusTrustlogged + req := &persistence.OperationQueryRequest{ + TrustlogStatus: &status, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(5)) // 5条已存证记录 + + for _, s := range result.Statuses { + assert.Equal(t, persistence.StatusTrustlogged, s) + } + t.Logf("✅ Trustlogged records: %d", result.Total) + }) + + // 测试5: 按 DOID 模糊查询 + t.Run("Filter by DOID pattern", func(t *testing.T) { + doid := "test-001" + req := &persistence.OperationQueryRequest{ + Doid: &doid, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条test-001的记录 + + for _, op := range result.Operations { + assert.Contains(t, op.Doid, "test-001") + } + t.Logf("✅ DOID pattern match records: %d", result.Total) + }) + + // 测试6: 按 OpActor 筛选 + t.Run("Filter by OpActor", func(t *testing.T) { + opActor := "user-1" + req := &persistence.OperationQueryRequest{ + OpActor: &opActor, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条user-1的记录 + + for _, op := range result.Operations { + assert.Equal(t, "user-1", op.OpActor) + } + t.Logf("✅ OpActor records: %d", result.Total) + }) + + // 测试7: 按 ProducerID 筛选 + t.Run("Filter by ProducerID", func(t *testing.T) { + producerID := "producer-1" + req := &persistence.OperationQueryRequest{ + ProducerID: &producerID, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条producer-1的记录 + + for _, op := range result.Operations { + assert.Equal(t, "producer-1", op.ProducerID) + } + t.Logf("✅ ProducerID records: %d", result.Total) + }) + + // 测试8: 按 ClientIP 筛选 + t.Run("Filter by ClientIP", func(t *testing.T) { + clientIP := "192.168.1.10" + req := &persistence.OperationQueryRequest{ + ClientIP: &clientIP, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(2)) // 2条192.168.1.10的记录 + + for _, op := range result.Operations { + assert.NotNil(t, op.ClientIP) + assert.Equal(t, "192.168.1.10", *op.ClientIP) + } + t.Logf("✅ ClientIP records: %d", result.Total) + }) + + // 测试9: 按 ServerIP 筛选 + t.Run("Filter by ServerIP", func(t *testing.T) { + serverIP := "10.0.0.1" + req := &persistence.OperationQueryRequest{ + ServerIP: &serverIP, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条10.0.0.1的记录 + + for _, op := range result.Operations { + assert.NotNil(t, op.ServerIP) + assert.Equal(t, "10.0.0.1", *op.ServerIP) + } + t.Logf("✅ ServerIP records: %d", result.Total) + }) + + // 测试10: 时间范围查询 + t.Run("Filter by time range", func(t *testing.T) { + timeFrom := baseTime.Add(30 * time.Minute) + timeTo := baseTime.Add(70 * time.Minute) + req := &persistence.OperationQueryRequest{ + TimeFrom: &timeFrom, + TimeTo: &timeTo, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 应该至少有3条记录在这个时间范围 + t.Logf("✅ Time range records: %d", result.Total) + + // 验证返回的记录在时间范围内 + for i, op := range result.Operations { + if !((op.Timestamp.After(timeFrom) || op.Timestamp.Equal(timeFrom)) && + (op.Timestamp.Before(timeTo) || op.Timestamp.Equal(timeTo))) { + t.Logf("⚠️ Record %d out of range: timestamp=%v, from=%v, to=%v", + i, op.Timestamp, timeFrom, timeTo) + } + } + }) + + // 测试11: 组合查询(OpSource + Status) + t.Run("Combined filter (OpSource + Status)", func(t *testing.T) { + opSource := "DOIP" + status := persistence.StatusTrustlogged + req := &persistence.OperationQueryRequest{ + OpSource: &opSource, + TrustlogStatus: &status, + PageSize: 50, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条已存证的DOIP记录 + + for i, op := range result.Operations { + assert.Equal(t, "DOIP", string(op.OpSource)) + assert.Equal(t, persistence.StatusTrustlogged, result.Statuses[i]) + } + t.Logf("✅ Combined filter records: %d", result.Total) + }) + + // 测试12: 分页查询 + t.Run("Pagination", func(t *testing.T) { + // 第1页 + req := &persistence.OperationQueryRequest{ + PageSize: 5, + PageNumber: 1, + OrderBy: "timestamp", + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(10)) + assert.LessOrEqual(t, len(result.Operations), 5) + firstPageFirst := result.Operations[0].OpID + + // 第2页 + req.PageNumber = 2 + result, err = repo.Query(ctx, req) + require.NoError(t, err) + assert.LessOrEqual(t, len(result.Operations), 5) + + // 确保第1页和第2页的数据不重复 + if len(result.Operations) > 0 { + assert.NotEqual(t, firstPageFirst, result.Operations[0].OpID) + } + + t.Logf("✅ Pagination works correctly, total pages: %d", result.TotalPages) + }) + + // 测试13: 排序(升序/降序) + t.Run("Ordering", func(t *testing.T) { + // 升序 + reqAsc := &persistence.OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + OrderBy: "timestamp", + OrderDesc: false, + } + + resultAsc, err := repo.Query(ctx, reqAsc) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(resultAsc.Operations), 10) + + // 验证升序 + for i := 1; i < len(resultAsc.Operations); i++ { + assert.True(t, resultAsc.Operations[i].Timestamp.After(resultAsc.Operations[i-1].Timestamp) || + resultAsc.Operations[i].Timestamp.Equal(resultAsc.Operations[i-1].Timestamp)) + } + + // 降序 + reqDesc := &persistence.OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + OrderBy: "timestamp", + OrderDesc: true, + } + + resultDesc, err := repo.Query(ctx, reqDesc) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(resultDesc.Operations), 10) + + // 验证降序 + for i := 1; i < len(resultDesc.Operations); i++ { + assert.True(t, resultDesc.Operations[i].Timestamp.Before(resultDesc.Operations[i-1].Timestamp) || + resultDesc.Operations[i].Timestamp.Equal(resultDesc.Operations[i-1].Timestamp)) + } + + t.Log("✅ Ordering (ASC/DESC) works correctly") + }) + + // 测试14: Count 统计 + t.Run("Count operations", func(t *testing.T) { + // 全部统计 + req := &persistence.OperationQueryRequest{} + count, err := repo.Count(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, count, int64(10)) + t.Logf("✅ Total count: %d", count) + + // 按状态统计 + status := persistence.StatusNotTrustlogged + req.TrustlogStatus = &status + count, err = repo.Count(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, count, int64(5)) + t.Logf("✅ NOT_TRUSTLOGGED count: %d", count) + }) + + // 测试15: OpID 精确查询 + t.Run("Query by OpID", func(t *testing.T) { + opID := "pg-query-test-001" + req := &persistence.OperationQueryRequest{ + OpID: &opID, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(1), result.Total) + assert.Len(t, result.Operations, 1) + assert.Equal(t, "pg-query-test-001", result.Operations[0].OpID) + t.Log("✅ OpID query works correctly") + }) + + // 测试16: 复杂组合查询(多条件) + t.Run("Complex combined query", func(t *testing.T) { + opSource := "DOIP" + opType := "Update" + status := persistence.StatusTrustlogged + req := &persistence.OperationQueryRequest{ + OpSource: &opSource, + OpType: &opType, + TrustlogStatus: &status, + PageSize: 50, + PageNumber: 1, + OrderBy: "timestamp", + OrderDesc: true, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(1)) + + for i, op := range result.Operations { + assert.Equal(t, "DOIP", string(op.OpSource)) + assert.Equal(t, "Update", op.OpType) + assert.Equal(t, persistence.StatusTrustlogged, result.Statuses[i]) + } + t.Logf("✅ Complex query records: %d", result.Total) + }) + + t.Log("✅ All PostgreSQL query integration tests passed") +} + +// TestPG_PersistenceClient_Query_Integration 测试 PersistenceClient 的查询功能 +func TestPG_PersistenceClient_Query_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping PostgreSQL PersistenceClient query integration test in short mode") + } + + ctx := context.Background() + log := logger.NewNopLogger() + + // 连接数据库 + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + e2eTestPGHost, e2eTestPGPort, e2eTestPGUser, e2eTestPGPassword, e2eTestPGDatabase) + + // 创建 PersistenceClient + dbConfig := persistence.DBConfig{ + DriverName: "postgres", + DSN: dsn, + MaxOpenConns: 20, + MaxIdleConns: 10, + ConnMaxLifetime: time.Hour, + } + + persistenceConfig := persistence.PersistenceConfig{ + Strategy: persistence.StrategyDBOnly, + } + + clientConfig := persistence.PersistenceClientConfig{ + Logger: log, + DBConfig: dbConfig, + PersistenceConfig: persistenceConfig, + } + + client, err := persistence.NewPersistenceClient(ctx, clientConfig) + if err != nil { + t.Skipf("PostgreSQL not available: %v", err) + return + } + defer client.Close() + + // 获取底层数据库连接进行清理和schema更新 + db, err := sql.Open("postgres", dsn) + require.NoError(t, err) + defer db.Close() + + // 清理测试数据 + _, _ = db.Exec("DELETE FROM operation WHERE op_id LIKE 'pg-client-query-%'") + defer func() { + _, _ = db.Exec("DELETE FROM operation WHERE op_id LIKE 'pg-client-query-%'") + }() + + // 确保schema是最新的 + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS op_hash VARCHAR(128)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS sign VARCHAR(512)") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS timestamp TIMESTAMP") + _, _ = db.Exec("ALTER TABLE operation ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP") + + t.Log("✅ PersistenceClient connected") + + // 创建测试数据(通过manager的repository) + manager := client.GetManager() + repo := manager.GetOperationRepo() + + for i := 0; i < 5; i++ { + op, err := model.NewFullOperation( + model.OpSourceDOIP, + string(model.OpTypeCreate), + "10.10000", // doPrefix + "client-repo", // doRepository + fmt.Sprintf("10.10000/client-repo/test-%d", i), // doid + fmt.Sprintf("client-producer-%d", i), // producerID + fmt.Sprintf("client-actor-%d", i), // opActor + nil, // requestBody + nil, // responseBody + time.Now(), // timestamp + ) + require.NoError(t, err) + op.OpID = fmt.Sprintf("pg-client-query-%03d", i) + + status := persistence.StatusNotTrustlogged + if i%2 == 0 { + status = persistence.StatusTrustlogged + } + + err = repo.Save(ctx, op, status) + require.NoError(t, err) + } + + t.Log("✅ Test data created via PersistenceClient") + + // 测试 QueryOperations + t.Run("QueryOperations", func(t *testing.T) { + req := &persistence.OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + } + + result, err := client.QueryOperations(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.GreaterOrEqual(t, result.Total, int64(5)) + t.Logf("✅ QueryOperations: total=%d", result.Total) + }) + + // 测试 CountOperations + t.Run("CountOperations", func(t *testing.T) { + req := &persistence.OperationQueryRequest{} + count, err := client.CountOperations(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, count, int64(5)) + t.Logf("✅ CountOperations: count=%d", count) + }) + + // 测试 GetOperationByID + t.Run("GetOperationByID", func(t *testing.T) { + op, status, err := client.GetOperationByID(ctx, "pg-client-query-000") + require.NoError(t, err) + assert.NotNil(t, op) + assert.Equal(t, "pg-client-query-000", op.OpID) + assert.Equal(t, persistence.StatusTrustlogged, status) + t.Log("✅ GetOperationByID works correctly") + }) + + // 测试按状态查询 + t.Run("Query by Status", func(t *testing.T) { + status := persistence.StatusTrustlogged + req := &persistence.OperationQueryRequest{ + TrustlogStatus: &status, + PageSize: 10, + PageNumber: 1, + } + + result, err := client.QueryOperations(ctx, req) + require.NoError(t, err) + assert.GreaterOrEqual(t, result.Total, int64(3)) // 3条已存证 + t.Logf("✅ Query by Status: total=%d", result.Total) + }) + + t.Log("✅ All PersistenceClient query integration tests passed") +} + +// strPtr 辅助函数:返回字符串指针 +func strPtr(s string) *string { + return &s +} + diff --git a/api/persistence/query_test.go b/api/persistence/query_test.go new file mode 100644 index 0000000..9e16197 --- /dev/null +++ b/api/persistence/query_test.go @@ -0,0 +1,290 @@ +package persistence + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.yandata.net/iod/iod/go-trustlog/api/logger" + "go.yandata.net/iod/iod/go-trustlog/api/model" +) + +func TestOperationRepository_Query(t *testing.T) { + ctx := context.Background() + log := logger.NewNopLogger() + db := setupTestDB(t) + defer db.Close() + + repo := NewOperationRepository(db, log) + + // 准备测试数据 + now := time.Now() + testOps := []struct { + opID string + opSource string + opType string + status TrustlogStatus + time time.Time + }{ + {"op-001", "DOIP", "Create", StatusNotTrustlogged, now.Add(-3 * time.Hour)}, + {"op-002", "DOIP", "Update", StatusTrustlogged, now.Add(-2 * time.Hour)}, + {"op-003", "IRP", "Create", StatusNotTrustlogged, now.Add(-1 * time.Hour)}, + {"op-004", "IRP", "Delete", StatusTrustlogged, now}, + } + + for _, testOp := range testOps { + op := createTestOperation(t, testOp.opID) + op.OpSource = model.Source(testOp.opSource) + op.OpType = testOp.opType + op.Timestamp = testOp.time + + err := repo.Save(ctx, op, testOp.status) + require.NoError(t, err) + } + + t.Run("Query all operations", func(t *testing.T) { + req := &OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int64(4), result.Total) + assert.Len(t, result.Operations, 4) + assert.Len(t, result.Statuses, 4) + }) + + t.Run("Query by OpSource", func(t *testing.T) { + opSource := "DOIP" + req := &OperationQueryRequest{ + OpSource: &opSource, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(2), result.Total) + assert.Len(t, result.Operations, 2) + + for _, op := range result.Operations { + assert.Equal(t, "DOIP", string(op.OpSource)) + } + }) + + t.Run("Query by OpType", func(t *testing.T) { + opType := "Create" + req := &OperationQueryRequest{ + OpType: &opType, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(2), result.Total) + }) + + t.Run("Query by TrustlogStatus", func(t *testing.T) { + status := StatusNotTrustlogged + req := &OperationQueryRequest{ + TrustlogStatus: &status, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(2), result.Total) + + for _, s := range result.Statuses { + assert.Equal(t, StatusNotTrustlogged, s) + } + }) + + t.Run("Query with time range", func(t *testing.T) { + timeFrom := now.Add(-2*time.Hour - 30*time.Minute) + timeTo := now.Add(-30 * time.Minute) + + req := &OperationQueryRequest{ + TimeFrom: &timeFrom, + TimeTo: &timeTo, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.True(t, result.Total >= 2) // 应该包含 op-002 和 op-003 + }) + + t.Run("Query with pagination", func(t *testing.T) { + req := &OperationQueryRequest{ + PageSize: 2, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(4), result.Total) + assert.Len(t, result.Operations, 2) + assert.Equal(t, 2, result.TotalPages) + + // 查询第二页 + req.PageNumber = 2 + result, err = repo.Query(ctx, req) + require.NoError(t, err) + assert.Len(t, result.Operations, 2) + }) + + t.Run("Query with ordering DESC", func(t *testing.T) { + req := &OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + OrderBy: "timestamp", + OrderDesc: true, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Len(t, result.Operations, 4) + + // 验证降序排列 + for i := 1; i < len(result.Operations); i++ { + // 后面的时间应该早于或等于前面的时间 + assert.True(t, result.Operations[i].Timestamp.Before(result.Operations[i-1].Timestamp) || + result.Operations[i].Timestamp.Equal(result.Operations[i-1].Timestamp)) + } + }) + + t.Run("Query by OpID", func(t *testing.T) { + opID := "op-001" + req := &OperationQueryRequest{ + OpID: &opID, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.Equal(t, int64(1), result.Total) + assert.Len(t, result.Operations, 1) + assert.Equal(t, "op-001", result.Operations[0].OpID) + }) + + t.Run("Query with Doid LIKE", func(t *testing.T) { + doid := "test-repo" + req := &OperationQueryRequest{ + Doid: &doid, + PageSize: 10, + PageNumber: 1, + } + + result, err := repo.Query(ctx, req) + require.NoError(t, err) + assert.True(t, result.Total >= 4) // 所有记录的 doid 都包含 "test-repo" + }) +} + +func TestOperationRepository_Count(t *testing.T) { + ctx := context.Background() + log := logger.NewNopLogger() + db := setupTestDB(t) + defer db.Close() + + repo := NewOperationRepository(db, log) + + // 准备测试数据 + for i := 0; i < 5; i++ { + op := createTestOperation(t, fmt.Sprintf("count-op-%d", i)) + status := StatusNotTrustlogged + if i%2 == 0 { + status = StatusTrustlogged + } + err := repo.Save(ctx, op, status) + require.NoError(t, err) + } + + t.Run("Count all", func(t *testing.T) { + req := &OperationQueryRequest{} + count, err := repo.Count(ctx, req) + require.NoError(t, err) + assert.True(t, count >= 5) + }) + + t.Run("Count by status", func(t *testing.T) { + status := StatusTrustlogged + req := &OperationQueryRequest{ + TrustlogStatus: &status, + } + count, err := repo.Count(ctx, req) + require.NoError(t, err) + assert.True(t, count >= 3) // i=0,2,4 三条记录 + }) +} + +func TestPersistenceClient_QueryOperations(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + log := logger.NewNopLogger() + db := setupTestDB(t) + defer db.Close() + + // 初始化 PersistenceManager + config := PersistenceConfig{ + Strategy: StrategyDBOnly, + } + + manager := NewPersistenceManager(db, config, log) + defer manager.Close() + + // 创建 PersistenceClient + client := &PersistenceClient{ + manager: manager, + logger: log, + } + + // 准备测试数据 + for i := 0; i < 3; i++ { + op := createTestOperation(t, fmt.Sprintf("client-op-%d", i)) + err := manager.GetOperationRepo().Save(ctx, op, StatusNotTrustlogged) + require.NoError(t, err) + } + + t.Run("QueryOperations", func(t *testing.T) { + req := &OperationQueryRequest{ + PageSize: 10, + PageNumber: 1, + } + + result, err := client.QueryOperations(ctx, req) + require.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.Total >= 3) + }) + + t.Run("CountOperations", func(t *testing.T) { + req := &OperationQueryRequest{} + count, err := client.CountOperations(ctx, req) + require.NoError(t, err) + assert.True(t, count >= 3) + }) + + t.Run("GetOperationByID", func(t *testing.T) { + op, status, err := client.GetOperationByID(ctx, "client-op-0") + require.NoError(t, err) + assert.NotNil(t, op) + assert.Equal(t, "client-op-0", op.OpID) + assert.Equal(t, StatusNotTrustlogged, status) + }) +} + diff --git a/api/persistence/repository.go b/api/persistence/repository.go index db236fc..69e4f39 100644 --- a/api/persistence/repository.go +++ b/api/persistence/repository.go @@ -10,6 +10,60 @@ import ( "go.yandata.net/iod/iod/go-trustlog/api/model" ) +// OperationQueryRequest 操作记录查询请求 +type OperationQueryRequest struct { + // OpID 操作ID(精确匹配) + OpID *string + // OpSource 操作来源(精确匹配) + OpSource *string + // OpType 操作类型(精确匹配) + OpType *string + // Doid 数字对象标识符(支持 LIKE 模糊查询) + Doid *string + // ProducerID 生产者ID(精确匹配) + ProducerID *string + // OpActor 操作执行者(精确匹配) + OpActor *string + // DoPrefix DO前缀(支持 LIKE 模糊查询) + DoPrefix *string + // DoRepository DO仓库(精确匹配) + DoRepository *string + // TrustlogStatus 存证状态(精确匹配) + TrustlogStatus *TrustlogStatus + // ClientIP 客户端IP(精确匹配) + ClientIP *string + // ServerIP 服务端IP(精确匹配) + ServerIP *string + // TimeFrom 时间范围查询-开始时间(闭区间) + TimeFrom *time.Time + // TimeTo 时间范围查询-结束时间(闭区间) + TimeTo *time.Time + // PageSize 每页数量(默认20,最大1000) + PageSize int + // PageNumber 页码(从1开始) + PageNumber int + // OrderBy 排序字段(created_at, timestamp, op_id) + OrderBy string + // OrderDesc 是否降序排序(默认 false 升序) + OrderDesc bool +} + +// OperationQueryResult 操作记录查询结果 +type OperationQueryResult struct { + // Operations 操作记录列表 + Operations []*model.Operation + // Statuses 对应的存证状态列表 + Statuses []TrustlogStatus + // Total 总记录数 + Total int64 + // PageSize 每页数量 + PageSize int + // PageNumber 当前页码 + PageNumber int + // TotalPages 总页数 + TotalPages int +} + // OperationRepository 操作记录数据库仓储接口 type OperationRepository interface { // Save 保存操作记录到数据库 @@ -32,6 +86,10 @@ type OperationRepository interface { // 只有当前状态匹配 expectedStatus 时才会更新 // 返回: updated (是否更新成功), error UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error) + // Query 根据条件查询操作记录(支持分页、筛选、排序) + Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) + // Count 统计符合条件的记录数 + Count(ctx context.Context, req *OperationQueryRequest) (int64, error) } // CursorRepository 游标仓储接口(Key-Value 模式) @@ -497,6 +555,310 @@ func (r *operationRepository) FindUntrustlogged(ctx context.Context, limit int) return operations, nil } +// Query 根据条件查询操作记录(支持分页、筛选、排序) +func (r *operationRepository) Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) { + if req == nil { + return nil, fmt.Errorf("query request cannot be nil") + } + + // 设置默认值 + pageSize := req.PageSize + if pageSize <= 0 { + pageSize = 20 + } + if pageSize > 1000 { + pageSize = 1000 + } + + pageNumber := req.PageNumber + if pageNumber <= 0 { + pageNumber = 1 + } + + orderBy := req.OrderBy + if orderBy == "" { + orderBy = "created_at" + } + // 防止 SQL 注入,只允许特定字段排序 + switch orderBy { + case "created_at", "timestamp", "op_id": + // 允许 + default: + orderBy = "created_at" + } + + orderDirection := "ASC" + if req.OrderDesc { + orderDirection = "DESC" + } + + // 构建 WHERE 子句 + var conditions []string + var args []interface{} + argIndex := 1 + + if req.OpID != nil && *req.OpID != "" { + conditions = append(conditions, fmt.Sprintf("op_id = $%d", argIndex)) + args = append(args, *req.OpID) + argIndex++ + } + if req.OpSource != nil && *req.OpSource != "" { + conditions = append(conditions, fmt.Sprintf("op_source = $%d", argIndex)) + args = append(args, *req.OpSource) + argIndex++ + } + if req.OpType != nil && *req.OpType != "" { + conditions = append(conditions, fmt.Sprintf("op_type = $%d", argIndex)) + args = append(args, *req.OpType) + argIndex++ + } + if req.Doid != nil && *req.Doid != "" { + conditions = append(conditions, fmt.Sprintf("doid LIKE $%d", argIndex)) + args = append(args, "%"+*req.Doid+"%") + argIndex++ + } + if req.ProducerID != nil && *req.ProducerID != "" { + conditions = append(conditions, fmt.Sprintf("producer_id = $%d", argIndex)) + args = append(args, *req.ProducerID) + argIndex++ + } + if req.OpActor != nil && *req.OpActor != "" { + conditions = append(conditions, fmt.Sprintf("op_actor = $%d", argIndex)) + args = append(args, *req.OpActor) + argIndex++ + } + if req.DoPrefix != nil && *req.DoPrefix != "" { + conditions = append(conditions, fmt.Sprintf("do_prefix LIKE $%d", argIndex)) + args = append(args, "%"+*req.DoPrefix+"%") + argIndex++ + } + if req.DoRepository != nil && *req.DoRepository != "" { + conditions = append(conditions, fmt.Sprintf("do_repository = $%d", argIndex)) + args = append(args, *req.DoRepository) + argIndex++ + } + if req.TrustlogStatus != nil { + conditions = append(conditions, fmt.Sprintf("trustlog_status = $%d", argIndex)) + args = append(args, string(*req.TrustlogStatus)) + argIndex++ + } + if req.ClientIP != nil && *req.ClientIP != "" { + conditions = append(conditions, fmt.Sprintf("client_ip = $%d", argIndex)) + args = append(args, *req.ClientIP) + argIndex++ + } + if req.ServerIP != nil && *req.ServerIP != "" { + conditions = append(conditions, fmt.Sprintf("server_ip = $%d", argIndex)) + args = append(args, *req.ServerIP) + argIndex++ + } + if req.TimeFrom != nil { + conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex)) + args = append(args, *req.TimeFrom) + argIndex++ + } + if req.TimeTo != nil { + conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex)) + args = append(args, *req.TimeTo) + argIndex++ + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + fmt.Sprintf("%s", conditions[0]) + for i := 1; i < len(conditions); i++ { + whereClause += " AND " + conditions[i] + } + } + + // 先查询总数 + total, err := r.Count(ctx, req) + if err != nil { + return nil, err + } + + // 查询数据 + offset := (pageNumber - 1) * pageSize + query := r.convertPlaceholders(fmt.Sprintf(` + SELECT + op_id, op_actor, doid, producer_id, + request_body_hash, response_body_hash, + op_source, op_type, do_prefix, do_repository, + client_ip, server_ip, trustlog_status, timestamp, created_at + FROM operation + %s + ORDER BY %s %s + LIMIT $%d OFFSET $%d + `, whereClause, orderBy, orderDirection, argIndex, argIndex+1)) + + args = append(args, pageSize, offset) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + r.logger.ErrorContext(ctx, "failed to query operations", + "error", err, + ) + return nil, fmt.Errorf("failed to query operations: %w", err) + } + defer rows.Close() + + var operations []*model.Operation + var statuses []TrustlogStatus + + for rows.Next() { + var op model.Operation + var reqHash, respHash, clientIP, serverIP, statusStr sql.NullString + var createdAt time.Time + + err := rows.Scan( + &op.OpID, &op.OpActor, &op.Doid, &op.ProducerID, + &reqHash, &respHash, + &op.OpSource, &op.OpType, &op.DoPrefix, &op.DoRepository, + &clientIP, &serverIP, &statusStr, &op.Timestamp, &createdAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan operation row: %w", err) + } + + // 处理可空字段 + if reqHash.Valid { + op.RequestBodyHash = &reqHash.String + } + if respHash.Valid { + op.ResponseBodyHash = &respHash.String + } + if clientIP.Valid { + op.ClientIP = &clientIP.String + } + if serverIP.Valid { + op.ServerIP = &serverIP.String + } + + operations = append(operations, &op) + statuses = append(statuses, TrustlogStatus(statusStr.String)) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating operation rows: %w", err) + } + + // 计算总页数 + totalPages := int(total) / pageSize + if int(total)%pageSize > 0 { + totalPages++ + } + + return &OperationQueryResult{ + Operations: operations, + Statuses: statuses, + Total: total, + PageSize: pageSize, + PageNumber: pageNumber, + TotalPages: totalPages, + }, nil +} + +// Count 统计符合条件的记录数 +func (r *operationRepository) Count(ctx context.Context, req *OperationQueryRequest) (int64, error) { + if req == nil { + return 0, fmt.Errorf("query request cannot be nil") + } + + // 构建 WHERE 子句 + var conditions []string + var args []interface{} + argIndex := 1 + + if req.OpID != nil && *req.OpID != "" { + conditions = append(conditions, fmt.Sprintf("op_id = $%d", argIndex)) + args = append(args, *req.OpID) + argIndex++ + } + if req.OpSource != nil && *req.OpSource != "" { + conditions = append(conditions, fmt.Sprintf("op_source = $%d", argIndex)) + args = append(args, *req.OpSource) + argIndex++ + } + if req.OpType != nil && *req.OpType != "" { + conditions = append(conditions, fmt.Sprintf("op_type = $%d", argIndex)) + args = append(args, *req.OpType) + argIndex++ + } + if req.Doid != nil && *req.Doid != "" { + conditions = append(conditions, fmt.Sprintf("doid LIKE $%d", argIndex)) + args = append(args, "%"+*req.Doid+"%") + argIndex++ + } + if req.ProducerID != nil && *req.ProducerID != "" { + conditions = append(conditions, fmt.Sprintf("producer_id = $%d", argIndex)) + args = append(args, *req.ProducerID) + argIndex++ + } + if req.OpActor != nil && *req.OpActor != "" { + conditions = append(conditions, fmt.Sprintf("op_actor = $%d", argIndex)) + args = append(args, *req.OpActor) + argIndex++ + } + if req.DoPrefix != nil && *req.DoPrefix != "" { + conditions = append(conditions, fmt.Sprintf("do_prefix LIKE $%d", argIndex)) + args = append(args, "%"+*req.DoPrefix+"%") + argIndex++ + } + if req.DoRepository != nil && *req.DoRepository != "" { + conditions = append(conditions, fmt.Sprintf("do_repository = $%d", argIndex)) + args = append(args, *req.DoRepository) + argIndex++ + } + if req.TrustlogStatus != nil { + conditions = append(conditions, fmt.Sprintf("trustlog_status = $%d", argIndex)) + args = append(args, string(*req.TrustlogStatus)) + argIndex++ + } + if req.ClientIP != nil && *req.ClientIP != "" { + conditions = append(conditions, fmt.Sprintf("client_ip = $%d", argIndex)) + args = append(args, *req.ClientIP) + argIndex++ + } + if req.ServerIP != nil && *req.ServerIP != "" { + conditions = append(conditions, fmt.Sprintf("server_ip = $%d", argIndex)) + args = append(args, *req.ServerIP) + argIndex++ + } + if req.TimeFrom != nil { + conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex)) + args = append(args, *req.TimeFrom) + argIndex++ + } + if req.TimeTo != nil { + conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex)) + args = append(args, *req.TimeTo) + argIndex++ + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + fmt.Sprintf("%s", conditions[0]) + for i := 1; i < len(conditions); i++ { + whereClause += " AND " + conditions[i] + } + } + + query := r.convertPlaceholders(fmt.Sprintf(` + SELECT COUNT(*) FROM operation %s + `, whereClause)) + + var count int64 + err := r.db.QueryRowContext(ctx, query, args...).Scan(&count) + if err != nil { + r.logger.ErrorContext(ctx, "failed to count operations", + "error", err, + ) + return 0, fmt.Errorf("failed to count operations: %w", err) + } + + return count, nil +} + // cursorRepository 游标仓储实现 type cursorRepository struct { db *sql.DB