feat: 添加 Operation 查询功能及完整测试
主要功能: - 新增 OperationQueryRequest/OperationQueryResult 结构体 - 实现 Repository.Query() - 支持多条件筛选、分页、排序 - 实现 Repository.Count() - 统计记录数 - 新增 PersistenceClient.QueryOperations/CountOperations/GetOperationByID 查询功能: - 支持按 OpID、OpSource、OpType、Doid 等字段筛选 - 支持模糊查询(Doid、DoPrefix) - 支持时间范围查询(TimeFrom/TimeTo) - 支持 IP 地址筛选(ClientIP、ServerIP) - 支持按 TrustlogStatus 筛选 - 支持组合查询 - 支持分页(PageSize、PageNumber) - 支持排序(OrderBy、OrderDesc) 测试覆盖: - ✅ query_test.go - 查询功能单元测试 - ✅ pg_query_integration_test.go - PostgreSQL 集成测试(16个测试用例) * Query all records * Filter by OpSource/OpType/Status/Actor/Producer/IP * DOID 模糊查询 * 时间范围查询 * 分页测试 * 排序测试(升序/降序) * 组合查询 * Count 统计 * PersistenceClient 接口测试 修复: - 修复 TestClusterSafety_MultipleCursorWorkers - 添加缺失字段 - 修复 TestCursorInitialization - 确保 schema 最新 - 添加自动 schema 更新(ALTER TABLE IF NOT EXISTS) 测试结果: - ✅ 所有单元测试通过(100%) - ✅ 所有集成测试通过(PostgreSQL、Pulsar、E2E) - ✅ Query 功能测试通过(16个测试用例)
This commit is contained in:
@@ -10,6 +10,60 @@ import (
|
||||
"go.yandata.net/iod/iod/go-trustlog/api/model"
|
||||
)
|
||||
|
||||
// OperationQueryRequest 操作记录查询请求
|
||||
type OperationQueryRequest struct {
|
||||
// OpID 操作ID(精确匹配)
|
||||
OpID *string
|
||||
// OpSource 操作来源(精确匹配)
|
||||
OpSource *string
|
||||
// OpType 操作类型(精确匹配)
|
||||
OpType *string
|
||||
// Doid 数字对象标识符(支持 LIKE 模糊查询)
|
||||
Doid *string
|
||||
// ProducerID 生产者ID(精确匹配)
|
||||
ProducerID *string
|
||||
// OpActor 操作执行者(精确匹配)
|
||||
OpActor *string
|
||||
// DoPrefix DO前缀(支持 LIKE 模糊查询)
|
||||
DoPrefix *string
|
||||
// DoRepository DO仓库(精确匹配)
|
||||
DoRepository *string
|
||||
// TrustlogStatus 存证状态(精确匹配)
|
||||
TrustlogStatus *TrustlogStatus
|
||||
// ClientIP 客户端IP(精确匹配)
|
||||
ClientIP *string
|
||||
// ServerIP 服务端IP(精确匹配)
|
||||
ServerIP *string
|
||||
// TimeFrom 时间范围查询-开始时间(闭区间)
|
||||
TimeFrom *time.Time
|
||||
// TimeTo 时间范围查询-结束时间(闭区间)
|
||||
TimeTo *time.Time
|
||||
// PageSize 每页数量(默认20,最大1000)
|
||||
PageSize int
|
||||
// PageNumber 页码(从1开始)
|
||||
PageNumber int
|
||||
// OrderBy 排序字段(created_at, timestamp, op_id)
|
||||
OrderBy string
|
||||
// OrderDesc 是否降序排序(默认 false 升序)
|
||||
OrderDesc bool
|
||||
}
|
||||
|
||||
// OperationQueryResult 操作记录查询结果
|
||||
type OperationQueryResult struct {
|
||||
// Operations 操作记录列表
|
||||
Operations []*model.Operation
|
||||
// Statuses 对应的存证状态列表
|
||||
Statuses []TrustlogStatus
|
||||
// Total 总记录数
|
||||
Total int64
|
||||
// PageSize 每页数量
|
||||
PageSize int
|
||||
// PageNumber 当前页码
|
||||
PageNumber int
|
||||
// TotalPages 总页数
|
||||
TotalPages int
|
||||
}
|
||||
|
||||
// OperationRepository 操作记录数据库仓储接口
|
||||
type OperationRepository interface {
|
||||
// Save 保存操作记录到数据库
|
||||
@@ -32,6 +86,10 @@ type OperationRepository interface {
|
||||
// 只有当前状态匹配 expectedStatus 时才会更新
|
||||
// 返回: updated (是否更新成功), error
|
||||
UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error)
|
||||
// Query 根据条件查询操作记录(支持分页、筛选、排序)
|
||||
Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error)
|
||||
// Count 统计符合条件的记录数
|
||||
Count(ctx context.Context, req *OperationQueryRequest) (int64, error)
|
||||
}
|
||||
|
||||
// CursorRepository 游标仓储接口(Key-Value 模式)
|
||||
@@ -497,6 +555,310 @@ func (r *operationRepository) FindUntrustlogged(ctx context.Context, limit int)
|
||||
return operations, nil
|
||||
}
|
||||
|
||||
// Query 根据条件查询操作记录(支持分页、筛选、排序)
|
||||
func (r *operationRepository) Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("query request cannot be nil")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
pageSize := req.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
pageNumber := req.PageNumber
|
||||
if pageNumber <= 0 {
|
||||
pageNumber = 1
|
||||
}
|
||||
|
||||
orderBy := req.OrderBy
|
||||
if orderBy == "" {
|
||||
orderBy = "created_at"
|
||||
}
|
||||
// 防止 SQL 注入,只允许特定字段排序
|
||||
switch orderBy {
|
||||
case "created_at", "timestamp", "op_id":
|
||||
// 允许
|
||||
default:
|
||||
orderBy = "created_at"
|
||||
}
|
||||
|
||||
orderDirection := "ASC"
|
||||
if req.OrderDesc {
|
||||
orderDirection = "DESC"
|
||||
}
|
||||
|
||||
// 构建 WHERE 子句
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
if req.OpID != nil && *req.OpID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_id = $%d", argIndex))
|
||||
args = append(args, *req.OpID)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpSource != nil && *req.OpSource != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_source = $%d", argIndex))
|
||||
args = append(args, *req.OpSource)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpType != nil && *req.OpType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_type = $%d", argIndex))
|
||||
args = append(args, *req.OpType)
|
||||
argIndex++
|
||||
}
|
||||
if req.Doid != nil && *req.Doid != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("doid LIKE $%d", argIndex))
|
||||
args = append(args, "%"+*req.Doid+"%")
|
||||
argIndex++
|
||||
}
|
||||
if req.ProducerID != nil && *req.ProducerID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("producer_id = $%d", argIndex))
|
||||
args = append(args, *req.ProducerID)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpActor != nil && *req.OpActor != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_actor = $%d", argIndex))
|
||||
args = append(args, *req.OpActor)
|
||||
argIndex++
|
||||
}
|
||||
if req.DoPrefix != nil && *req.DoPrefix != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("do_prefix LIKE $%d", argIndex))
|
||||
args = append(args, "%"+*req.DoPrefix+"%")
|
||||
argIndex++
|
||||
}
|
||||
if req.DoRepository != nil && *req.DoRepository != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("do_repository = $%d", argIndex))
|
||||
args = append(args, *req.DoRepository)
|
||||
argIndex++
|
||||
}
|
||||
if req.TrustlogStatus != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("trustlog_status = $%d", argIndex))
|
||||
args = append(args, string(*req.TrustlogStatus))
|
||||
argIndex++
|
||||
}
|
||||
if req.ClientIP != nil && *req.ClientIP != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("client_ip = $%d", argIndex))
|
||||
args = append(args, *req.ClientIP)
|
||||
argIndex++
|
||||
}
|
||||
if req.ServerIP != nil && *req.ServerIP != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("server_ip = $%d", argIndex))
|
||||
args = append(args, *req.ServerIP)
|
||||
argIndex++
|
||||
}
|
||||
if req.TimeFrom != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||||
args = append(args, *req.TimeFrom)
|
||||
argIndex++
|
||||
}
|
||||
if req.TimeTo != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||||
args = append(args, *req.TimeTo)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = "WHERE " + fmt.Sprintf("%s", conditions[0])
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
whereClause += " AND " + conditions[i]
|
||||
}
|
||||
}
|
||||
|
||||
// 先查询总数
|
||||
total, err := r.Count(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
offset := (pageNumber - 1) * pageSize
|
||||
query := r.convertPlaceholders(fmt.Sprintf(`
|
||||
SELECT
|
||||
op_id, op_actor, doid, producer_id,
|
||||
request_body_hash, response_body_hash,
|
||||
op_source, op_type, do_prefix, do_repository,
|
||||
client_ip, server_ip, trustlog_status, timestamp, created_at
|
||||
FROM operation
|
||||
%s
|
||||
ORDER BY %s %s
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, orderBy, orderDirection, argIndex, argIndex+1))
|
||||
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, "failed to query operations",
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to query operations: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var operations []*model.Operation
|
||||
var statuses []TrustlogStatus
|
||||
|
||||
for rows.Next() {
|
||||
var op model.Operation
|
||||
var reqHash, respHash, clientIP, serverIP, statusStr sql.NullString
|
||||
var createdAt time.Time
|
||||
|
||||
err := rows.Scan(
|
||||
&op.OpID, &op.OpActor, &op.Doid, &op.ProducerID,
|
||||
&reqHash, &respHash,
|
||||
&op.OpSource, &op.OpType, &op.DoPrefix, &op.DoRepository,
|
||||
&clientIP, &serverIP, &statusStr, &op.Timestamp, &createdAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan operation row: %w", err)
|
||||
}
|
||||
|
||||
// 处理可空字段
|
||||
if reqHash.Valid {
|
||||
op.RequestBodyHash = &reqHash.String
|
||||
}
|
||||
if respHash.Valid {
|
||||
op.ResponseBodyHash = &respHash.String
|
||||
}
|
||||
if clientIP.Valid {
|
||||
op.ClientIP = &clientIP.String
|
||||
}
|
||||
if serverIP.Valid {
|
||||
op.ServerIP = &serverIP.String
|
||||
}
|
||||
|
||||
operations = append(operations, &op)
|
||||
statuses = append(statuses, TrustlogStatus(statusStr.String))
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating operation rows: %w", err)
|
||||
}
|
||||
|
||||
// 计算总页数
|
||||
totalPages := int(total) / pageSize
|
||||
if int(total)%pageSize > 0 {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
return &OperationQueryResult{
|
||||
Operations: operations,
|
||||
Statuses: statuses,
|
||||
Total: total,
|
||||
PageSize: pageSize,
|
||||
PageNumber: pageNumber,
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Count 统计符合条件的记录数
|
||||
func (r *operationRepository) Count(ctx context.Context, req *OperationQueryRequest) (int64, error) {
|
||||
if req == nil {
|
||||
return 0, fmt.Errorf("query request cannot be nil")
|
||||
}
|
||||
|
||||
// 构建 WHERE 子句
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
argIndex := 1
|
||||
|
||||
if req.OpID != nil && *req.OpID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_id = $%d", argIndex))
|
||||
args = append(args, *req.OpID)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpSource != nil && *req.OpSource != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_source = $%d", argIndex))
|
||||
args = append(args, *req.OpSource)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpType != nil && *req.OpType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_type = $%d", argIndex))
|
||||
args = append(args, *req.OpType)
|
||||
argIndex++
|
||||
}
|
||||
if req.Doid != nil && *req.Doid != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("doid LIKE $%d", argIndex))
|
||||
args = append(args, "%"+*req.Doid+"%")
|
||||
argIndex++
|
||||
}
|
||||
if req.ProducerID != nil && *req.ProducerID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("producer_id = $%d", argIndex))
|
||||
args = append(args, *req.ProducerID)
|
||||
argIndex++
|
||||
}
|
||||
if req.OpActor != nil && *req.OpActor != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("op_actor = $%d", argIndex))
|
||||
args = append(args, *req.OpActor)
|
||||
argIndex++
|
||||
}
|
||||
if req.DoPrefix != nil && *req.DoPrefix != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("do_prefix LIKE $%d", argIndex))
|
||||
args = append(args, "%"+*req.DoPrefix+"%")
|
||||
argIndex++
|
||||
}
|
||||
if req.DoRepository != nil && *req.DoRepository != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("do_repository = $%d", argIndex))
|
||||
args = append(args, *req.DoRepository)
|
||||
argIndex++
|
||||
}
|
||||
if req.TrustlogStatus != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("trustlog_status = $%d", argIndex))
|
||||
args = append(args, string(*req.TrustlogStatus))
|
||||
argIndex++
|
||||
}
|
||||
if req.ClientIP != nil && *req.ClientIP != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("client_ip = $%d", argIndex))
|
||||
args = append(args, *req.ClientIP)
|
||||
argIndex++
|
||||
}
|
||||
if req.ServerIP != nil && *req.ServerIP != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("server_ip = $%d", argIndex))
|
||||
args = append(args, *req.ServerIP)
|
||||
argIndex++
|
||||
}
|
||||
if req.TimeFrom != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||||
args = append(args, *req.TimeFrom)
|
||||
argIndex++
|
||||
}
|
||||
if req.TimeTo != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||||
args = append(args, *req.TimeTo)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = "WHERE " + fmt.Sprintf("%s", conditions[0])
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
whereClause += " AND " + conditions[i]
|
||||
}
|
||||
}
|
||||
|
||||
query := r.convertPlaceholders(fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM operation %s
|
||||
`, whereClause))
|
||||
|
||||
var count int64
|
||||
err := r.db.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
r.logger.ErrorContext(ctx, "failed to count operations",
|
||||
"error", err,
|
||||
)
|
||||
return 0, fmt.Errorf("failed to count operations: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// cursorRepository 游标仓储实现
|
||||
type cursorRepository struct {
|
||||
db *sql.DB
|
||||
|
||||
Reference in New Issue
Block a user