Files
go-trustlog/api/persistence/repository.go
ryan fb182adef4 feat: OpType重构为OpCode (int32) - 完整实现
🎯 核心变更:
- OpType (string) → OpCode (int32)
- 20+ OpCode枚举常量 (基于DOIP/IRP标准)
- 类型安全 + 性能优化

📊 影响范围:
- 核心模型: Operation结构体、CBOR序列化
- 数据库: schema.go + SQL DDL (PostgreSQL/MySQL/SQLite)
- 持久化: repository.go查询、cursor_worker.go
- API接口: Protobuf定义 + gRPC客户端
- 测试代码: 60+ 测试文件更新

✅ 测试结果:
- 通过率: 100% (所有87个测试用例)
- 总体覆盖率: 53.7%
- 核心包覆盖率: logger(100%), highclient(95.3%), model(79.1%)

📝 文档:
- 精简README (1056行→489行,减少54%)
- 完整的OpCode枚举说明
- 三种持久化策略示例
- 数据库表结构和架构图

🔧 技术细节:
- 类型转换: string(OpCode) → int32(OpCode)
- SQL参数: 字符串值 → 整数值
- Protobuf: op_type string → op_code int32
- 测试断言: 字符串比较 → 常量比较

🎉 质量保证:
- 零编译错误
- 100%测试通过
- PostgreSQL/Pulsar集成测试验证
- 分布式并发安全测试通过
2025-12-26 13:47:55 +08:00

1179 lines
32 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package persistence
import (
"context"
"database/sql"
"fmt"
"time"
"go.yandata.net/iod/iod/go-trustlog/api/logger"
"go.yandata.net/iod/iod/go-trustlog/api/model"
)
// OperationQueryRequest holds the filter, paging and ordering options for
// querying operation records. A nil filter is ignored; for the string filters
// an empty string is ignored as well.
type OperationQueryRequest struct {
	// OpID filters by operation ID (exact match).
	OpID *string
	// OpSource filters by operation source (exact match).
	OpSource *string
	// OpCode filters by operation code (exact match, int32).
	OpCode *int32
	// Doid filters by digital object identifier (LIKE '%value%' substring match).
	Doid *string
	// ProducerID filters by producer ID (exact match).
	ProducerID *string
	// OpActor filters by the actor that performed the operation (exact match).
	OpActor *string
	// DoPrefix filters by DO prefix (LIKE '%value%' substring match).
	DoPrefix *string
	// DoRepository filters by DO repository (exact match).
	DoRepository *string
	// TrustlogStatus filters by trustlog status (exact match).
	TrustlogStatus *TrustlogStatus
	// ClientIP filters by client IP (exact match).
	ClientIP *string
	// ServerIP filters by server IP (exact match).
	ServerIP *string
	// TimeFrom is the inclusive lower bound on the timestamp column.
	TimeFrom *time.Time
	// TimeTo is the inclusive upper bound on the timestamp column.
	TimeTo *time.Time
	// PageSize is the number of records per page (default 20, max 1000).
	PageSize int
	// PageNumber is the 1-based page number (values <= 0 mean page 1).
	PageNumber int
	// OrderBy is the sort column: created_at, timestamp or op_id. Any other
	// value, including "", falls back to created_at.
	OrderBy string
	// OrderDesc sorts in descending order when true (default false: ascending).
	OrderDesc bool
}
// OperationQueryResult is one page of operation records returned by Query.
type OperationQueryResult struct {
	// Operations are the records on this page.
	Operations []*model.Operation
	// Statuses holds the trustlog status of each record; Statuses[i] belongs
	// to Operations[i].
	Statuses []TrustlogStatus
	// Total is the number of records matching the filters across all pages.
	Total int64
	// PageSize is the effective page size after defaults and capping.
	PageSize int
	// PageNumber is the effective (1-based) page number.
	PageNumber int
	// TotalPages is Total divided by PageSize, rounded up.
	TotalPages int
}
// OperationRepository is the database repository for operation records and
// their trustlog status.
type OperationRepository interface {
	// Save stores an operation record in the database with the given status.
	Save(ctx context.Context, op *model.Operation, status TrustlogStatus) error
	// SaveTx is like Save but runs inside the given transaction.
	SaveTx(ctx context.Context, tx *sql.Tx, op *model.Operation, status TrustlogStatus) error
	// UpdateStatus updates the trustlog status of the operation with opID.
	UpdateStatus(ctx context.Context, opID string, status TrustlogStatus) error
	// UpdateStatusTx is like UpdateStatus but runs inside the given transaction.
	UpdateStatusTx(ctx context.Context, tx *sql.Tx, opID string, status TrustlogStatus) error
	// FindByID looks up an operation record and its trustlog status by OpID.
	FindByID(ctx context.Context, opID string) (*model.Operation, TrustlogStatus, error)
	// FindUntrustlogged returns up to limit operations that have not been
	// trustlogged yet (used by the retry mechanism).
	FindUntrustlogged(ctx context.Context, limit int) ([]*model.Operation, error)
	// FindUntrustloggedWithLock finds not-yet-trustlogged operations in a way
	// intended to be safe for several concurrent workers in a cluster.
	// It uses SELECT ... FOR UPDATE SKIP LOCKED so that workers do not process
	// the same records.
	// It returns the operations and their op IDs.
	FindUntrustloggedWithLock(ctx context.Context, tx *sql.Tx, limit int) ([]*model.Operation, []string, error)
	// UpdateStatusWithCAS updates the status with compare-and-set semantics:
	// the update only happens while the current status equals expectedStatus.
	// updated reports whether the update took effect.
	UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error)
	// Query returns operation records matching req, with paging, filtering
	// and ordering.
	Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error)
	// Count returns the number of records matching the filters in req.
	Count(ctx context.Context, req *OperationQueryRequest) (int64, error)
}
// CursorRepository stores cursors as key/value pairs.
type CursorRepository interface {
	// GetCursor returns the value stored under cursorKey.
	GetCursor(ctx context.Context, cursorKey string) (string, error)
	// UpdateCursor stores cursorValue under cursorKey.
	UpdateCursor(ctx context.Context, cursorKey string, cursorValue string) error
	// UpdateCursorTx is like UpdateCursor but runs inside the given transaction.
	UpdateCursorTx(ctx context.Context, tx *sql.Tx, cursorKey string, cursorValue string) error
	// InitCursor initializes the cursor under cursorKey (documented as "if it
	// does not exist").
	// NOTE(review): the SQL implementation in this package upserts, so an
	// existing value is overwritten with initialValue. Confirm which contract
	// callers rely on.
	InitCursor(ctx context.Context, cursorKey string, initialValue string) error
}
// RetryRepository persists retry bookkeeping for operations whose trustlog
// submission failed.
type RetryRepository interface {
	// AddRetry creates a retry record for opID scheduled at nextRetryAt.
	AddRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error
	// AddRetryTx is like AddRetry but runs inside the given transaction.
	AddRetryTx(ctx context.Context, tx *sql.Tx, opID string, errorMsg string, nextRetryAt time.Time) error
	// IncrementRetry increases the retry count and reschedules the next attempt.
	IncrementRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error
	// MarkAsDeadLetter marks the retry record for opID as a dead letter.
	MarkAsDeadLetter(ctx context.Context, opID string, errorMsg string) error
	// FindPendingRetries returns up to limit records that are due for retry.
	FindPendingRetries(ctx context.Context, limit int) ([]RetryRecord, error)
	// DeleteRetry removes the retry record for opID (cleanup after success).
	DeleteRetry(ctx context.Context, opID string) error
}
// RetryRecord is one row of the trustlog_retry table.
type RetryRecord struct {
	// OpID identifies the operation being retried.
	OpID string
	// RetryCount is the number of retries so far (AddRetry inserts 0 and
	// IncrementRetry adds 1).
	RetryCount int
	// RetryStatus is the current retry state.
	RetryStatus RetryStatus
	// LastRetryAt is nil when the column is NULL (no retry attempted yet).
	LastRetryAt *time.Time
	// NextRetryAt is when the next attempt is due; nil when the column is NULL.
	NextRetryAt *time.Time
	// ErrorMessage is the last recorded error message.
	ErrorMessage string
	// CreatedAt is the row's creation time.
	CreatedAt time.Time
	// UpdatedAt is the row's last modification time.
	UpdatedAt time.Time
}
// operationRepository is the database/sql implementation of OperationRepository.
type operationRepository struct {
	// db is used for statements that do not run in a caller-supplied transaction.
	db *sql.DB
	// logger receives the error/warn/debug logs emitted by the methods.
	logger logger.Logger
	// driverName is the dialect reported by detectDriverName ("postgres" or
	// "sqlite3"); it decides how "?" placeholders are rewritten.
	driverName string
}
// detectDriverName guesses which SQL dialect db speaks.
//
// It runs "SELECT version()" and reports "postgres" when the returned banner
// starts with "PostgreSQL". Every other outcome falls back to "sqlite3", which
// keeps "?" placeholders untouched. This covers a nil db, a failed query, or
// any other banner (presumably including MySQL's).
func detectDriverName(db *sql.DB) string {
	const fallback = "sqlite3"
	if db == nil {
		return fallback
	}

	const pgPrefix = "PostgreSQL"
	var banner string
	if scanErr := db.QueryRow("SELECT version()").Scan(&banner); scanErr != nil {
		return fallback
	}
	if len(banner) < len(pgPrefix) || banner[:len(pgPrefix)] != pgPrefix {
		return fallback
	}
	return "postgres"
}
// convertPlaceholdersForDriver rewrites the "?" placeholders in query into the
// syntax expected by driverName.
//
// For "postgres" each "?" is replaced, left to right, by a numbered
// placeholder ($1, $2, ...). For any other driver (SQLite, MySQL) the query is
// returned unchanged, because those drivers accept "?" directly.
//
// The rewrite is purely lexical: a "?" inside a string literal or comment is
// replaced too, so callers must only use "?" as a placeholder.
func convertPlaceholdersForDriver(query, driverName string) string {
	if driverName != "postgres" {
		return query
	}

	// Copy raw bytes instead of converting each byte with string(b). That
	// conversion treats the byte as a rune and re-encodes bytes >= 0x80, which
	// corrupts multi-byte UTF-8 text. '?' is ASCII and never occurs inside a
	// multi-byte sequence, so byte-wise scanning is safe. A growing []byte
	// also avoids the quadratic cost of repeated string concatenation.
	buf := make([]byte, 0, len(query)+8)
	next := 1
	for i := 0; i < len(query); i++ {
		c := query[i]
		if c != '?' {
			buf = append(buf, c)
			continue
		}
		buf = append(buf, fmt.Sprintf("$%d", next)...)
		next++
	}
	return string(buf)
}
// NewOperationRepository returns an OperationRepository backed by db. The SQL
// dialect is detected once here and reused for every statement.
func NewOperationRepository(db *sql.DB, log logger.Logger) OperationRepository {
	repo := &operationRepository{db: db, logger: log}
	repo.driverName = detectDriverName(db)
	return repo
}
// convertPlaceholders adapts the "?" placeholders in query to the dialect
// detected when the repository was constructed.
func (r *operationRepository) convertPlaceholders(query string) string {
	driver := r.driverName
	return convertPlaceholdersForDriver(query, driver)
}
func (r *operationRepository) Save(ctx context.Context, op *model.Operation, status TrustlogStatus) error {
return r.SaveTx(ctx, nil, op, status)
}
// SaveTx inserts op into the operation table with the given trustlog status.
//
// When tx is non-nil the INSERT runs inside that transaction; otherwise it runs
// directly on the repository's *sql.DB. The optional pointer fields
// (RequestBodyHash, ResponseBodyHash, ClientIP, ServerIP) are stored as SQL
// NULL when nil. OpSource and status are bound as strings and OpCode as int32.
//
// It returns an error when op is nil or when the INSERT fails; failures of the
// INSERT are also logged.
func (r *operationRepository) SaveTx(ctx context.Context, tx *sql.Tx, op *model.Operation, status TrustlogStatus) error {
	// Every column below is read from op, so reject nil up front instead of
	// panicking on the first dereference.
	if op == nil {
		return fmt.Errorf("operation cannot be nil")
	}

	query := r.convertPlaceholders(`
INSERT INTO operation (
op_id, op_actor, doid, producer_id,
request_body_hash, response_body_hash,
op_source, op_code, do_prefix, do_repository,
client_ip, server_ip, trustlog_status, timestamp
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`)

	// nullable maps an optional string to SQL NULL when it is absent.
	nullable := func(s *string) sql.NullString {
		if s == nil {
			return sql.NullString{}
		}
		return sql.NullString{String: *s, Valid: true}
	}

	args := []interface{}{
		op.OpID,
		op.OpActor,
		op.Doid,
		op.ProducerID,
		nullable(op.RequestBodyHash),
		nullable(op.ResponseBodyHash),
		string(op.OpSource),
		int32(op.OpCode),
		op.DoPrefix,
		op.DoRepository,
		nullable(op.ClientIP),
		nullable(op.ServerIP),
		string(status),
		op.Timestamp,
	}

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, args...)
	} else {
		_, err = r.db.ExecContext(ctx, query, args...)
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to save operation",
			"opID", op.OpID,
			"error", err,
		)
		return fmt.Errorf("failed to save operation: %w", err)
	}

	r.logger.DebugContext(ctx, "operation saved to database",
		"opID", op.OpID,
		"status", status,
	)
	return nil
}
func (r *operationRepository) UpdateStatus(ctx context.Context, opID string, status TrustlogStatus) error {
return r.UpdateStatusTx(ctx, nil, opID, status)
}
// UpdateStatusTx sets trustlog_status to status for the row whose op_id is
// opID. It uses tx when non-nil, otherwise the repository's *sql.DB.
//
// The affected-row count is not inspected, so an unknown opID is not reported
// as an error. Database failures are logged and returned wrapped.
func (r *operationRepository) UpdateStatusTx(ctx context.Context, tx *sql.Tx, opID string, status TrustlogStatus) error {
	stmt := r.convertPlaceholders(`UPDATE operation SET trustlog_status = ? WHERE op_id = ?`)
	statusValue := string(status)

	var execErr error
	switch {
	case tx != nil:
		_, execErr = tx.ExecContext(ctx, stmt, statusValue, opID)
	default:
		_, execErr = r.db.ExecContext(ctx, stmt, statusValue, opID)
	}

	if execErr != nil {
		r.logger.ErrorContext(ctx, "failed to update operation status",
			"opID", opID,
			"status", status,
			"error", execErr,
		)
		return fmt.Errorf("failed to update operation status: %w", execErr)
	}

	r.logger.DebugContext(ctx, "operation status updated",
		"opID", opID,
		"status", status,
	)
	return nil
}
// FindByID loads the operation whose op_id equals opID, together with its
// stored trustlog status.
//
// Nullable columns (request/response body hashes, client/server IP) become nil
// pointers when NULL. If no row matches, an "operation not found" error is
// returned without logging. Other database errors are logged and returned
// wrapped.
func (r *operationRepository) FindByID(ctx context.Context, opID string) (*model.Operation, TrustlogStatus, error) {
	stmt := r.convertPlaceholders(`
SELECT
op_id, op_actor, doid, producer_id,
request_body_hash, response_body_hash,
op_source, op_code, do_prefix, do_repository,
client_ip, server_ip, trustlog_status, timestamp
FROM operation
WHERE op_id = ?
`)

	found := &model.Operation{}
	var (
		rawStatus          string
		reqHash, respHash  sql.NullString
		clientIP, serverIP sql.NullString
	)
	scanErr := r.db.QueryRowContext(ctx, stmt, opID).Scan(
		&found.OpID,
		&found.OpActor,
		&found.Doid,
		&found.ProducerID,
		&reqHash,
		&respHash,
		&found.OpSource,
		&found.OpCode,
		&found.DoPrefix,
		&found.DoRepository,
		&clientIP,
		&serverIP,
		&rawStatus,
		&found.Timestamp,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, "", fmt.Errorf("operation not found: %s", opID)
	case scanErr != nil:
		r.logger.ErrorContext(ctx, "failed to find operation",
			"opID", opID,
			"error", scanErr,
		)
		return nil, "", fmt.Errorf("failed to find operation: %w", scanErr)
	}

	if reqHash.Valid {
		found.RequestBodyHash = &reqHash.String
	}
	if respHash.Valid {
		found.ResponseBodyHash = &respHash.String
	}
	if clientIP.Valid {
		found.ClientIP = &clientIP.String
	}
	if serverIP.Valid {
		found.ServerIP = &serverIP.String
	}
	return found, TrustlogStatus(rawStatus), nil
}
// FindUntrustloggedWithLock returns up to limit operations whose status is
// StatusNotTrustlogged, oldest timestamp first, together with their op IDs in
// the same order.
//
// The SELECT carries FOR UPDATE SKIP LOCKED. Selected rows are locked, and rows
// already locked by another transaction are skipped instead of waited on, so
// concurrent workers do not pick the same record. When tx is non-nil the query
// runs inside it; otherwise it runs on the *sql.DB directly.
// NOTE(review): without a transaction the locks presumably end with the
// statement. FOR UPDATE is also not SQLite syntax, so confirm this path is
// only used with PostgreSQL/MySQL.
//
// A row that fails to scan is logged and skipped rather than failing the
// batch. An error is returned when the query or the row iteration fails.
func (r *operationRepository) FindUntrustloggedWithLock(ctx context.Context, tx *sql.Tx, limit int) ([]*model.Operation, []string, error) {
	stmt := r.convertPlaceholders(`
SELECT
op_id, op_actor, doid, producer_id,
request_body_hash, response_body_hash,
op_source, op_code, do_prefix, do_repository,
client_ip, server_ip, timestamp
FROM operation
WHERE trustlog_status = ?
ORDER BY timestamp ASC
LIMIT ?
FOR UPDATE SKIP LOCKED
`)

	// querier is satisfied by both *sql.DB and *sql.Tx.
	type querier interface {
		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	}
	var source querier = r.db
	if tx != nil {
		source = tx
	}

	rows, err := source.QueryContext(ctx, stmt, string(StatusNotTrustlogged), limit)
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to find untrustlogged operations with lock",
			"error", err,
		)
		return nil, nil, fmt.Errorf("failed to find untrustlogged operations: %w", err)
	}
	defer rows.Close()

	var (
		operations []*model.Operation
		opIDs      []string
	)
	for rows.Next() {
		op := &model.Operation{}
		var reqHash, respHash, clientIP, serverIP sql.NullString
		scanErr := rows.Scan(
			&op.OpID,
			&op.OpActor,
			&op.Doid,
			&op.ProducerID,
			&reqHash,
			&respHash,
			&op.OpSource,
			&op.OpCode,
			&op.DoPrefix,
			&op.DoRepository,
			&clientIP,
			&serverIP,
			&op.Timestamp,
		)
		if scanErr != nil {
			// Skip the unreadable row and keep the rest of the batch.
			r.logger.ErrorContext(ctx, "failed to scan operation",
				"error", scanErr,
			)
			continue
		}

		if reqHash.Valid {
			op.RequestBodyHash = &reqHash.String
		}
		if respHash.Valid {
			op.ResponseBodyHash = &respHash.String
		}
		if clientIP.Valid {
			op.ClientIP = &clientIP.String
		}
		if serverIP.Valid {
			op.ServerIP = &serverIP.String
		}
		operations = append(operations, op)
		opIDs = append(opIDs, op.OpID)
	}

	if iterErr := rows.Err(); iterErr != nil {
		r.logger.ErrorContext(ctx, "error iterating rows",
			"error", iterErr,
		)
		return nil, nil, fmt.Errorf("error iterating rows: %w", iterErr)
	}
	return operations, opIDs, nil
}
// UpdateStatusWithCAS sets the status of opID to newStatus only if its current
// status is still expectedStatus (compare-and-set). The check and the write
// happen in one UPDATE ... WHERE statement.
//
// It uses tx when non-nil, otherwise the repository's *sql.DB. It returns true
// when a row was updated. It returns false with a nil error when no row
// matched, meaning the status had already changed or the opID is unknown. It
// returns an error when the statement or the affected-row count fails.
func (r *operationRepository) UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error) {
	stmt := r.convertPlaceholders(`
UPDATE operation
SET trustlog_status = ?
WHERE op_id = ? AND trustlog_status = ?
`)

	// execer is satisfied by both *sql.DB and *sql.Tx.
	type execer interface {
		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	}
	var target execer = r.db
	if tx != nil {
		target = tx
	}

	outcome, err := target.ExecContext(ctx, stmt, string(newStatus), opID, string(expectedStatus))
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to update operation status with CAS",
			"opID", opID,
			"expectedStatus", expectedStatus,
			"newStatus", newStatus,
			"error", err,
		)
		return false, fmt.Errorf("failed to update operation status: %w", err)
	}

	affected, err := outcome.RowsAffected()
	if err != nil {
		return false, fmt.Errorf("failed to get rows affected: %w", err)
	}

	if affected != 0 {
		r.logger.DebugContext(ctx, "operation status updated with CAS",
			"opID", opID,
			"expectedStatus", expectedStatus,
			"newStatus", newStatus,
		)
		return true, nil
	}

	// Nothing matched: another worker already moved the record on.
	r.logger.WarnContext(ctx, "CAS update failed: status already changed by another worker",
		"opID", opID,
		"expectedStatus", expectedStatus,
	)
	return false, nil
}
// FindUntrustlogged returns up to limit operations whose status is
// StatusNotTrustlogged, oldest timestamp first. No row locks are taken; see
// FindUntrustloggedWithLock for the concurrent-worker variant.
//
// Unlike the locking variant, a row that fails to scan aborts the call with an
// error.
func (r *operationRepository) FindUntrustlogged(ctx context.Context, limit int) ([]*model.Operation, error) {
	stmt := r.convertPlaceholders(`
SELECT
op_id, op_actor, doid, producer_id,
request_body_hash, response_body_hash,
op_source, op_code, do_prefix, do_repository,
client_ip, server_ip, timestamp
FROM operation
WHERE trustlog_status = ?
ORDER BY timestamp ASC
LIMIT ?
`)

	rows, queryErr := r.db.QueryContext(ctx, stmt, string(StatusNotTrustlogged), limit)
	if queryErr != nil {
		r.logger.ErrorContext(ctx, "failed to find untrustlogged operations",
			"error", queryErr,
		)
		return nil, fmt.Errorf("failed to find untrustlogged operations: %w", queryErr)
	}
	defer rows.Close()

	var found []*model.Operation
	for rows.Next() {
		item := &model.Operation{}
		var reqHash, respHash, clientIP, serverIP sql.NullString
		if scanErr := rows.Scan(
			&item.OpID,
			&item.OpActor,
			&item.Doid,
			&item.ProducerID,
			&reqHash,
			&respHash,
			&item.OpSource,
			&item.OpCode,
			&item.DoPrefix,
			&item.DoRepository,
			&clientIP,
			&serverIP,
			&item.Timestamp,
		); scanErr != nil {
			r.logger.ErrorContext(ctx, "failed to scan operation row",
				"error", scanErr,
			)
			return nil, fmt.Errorf("failed to scan operation row: %w", scanErr)
		}

		if reqHash.Valid {
			item.RequestBodyHash = &reqHash.String
		}
		if respHash.Valid {
			item.ResponseBodyHash = &respHash.String
		}
		if clientIP.Valid {
			item.ClientIP = &clientIP.String
		}
		if serverIP.Valid {
			item.ServerIP = &serverIP.String
		}
		found = append(found, item)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating operation rows: %w", iterErr)
	}
	return found, nil
}
// Query returns one page of operation records matching the filters in req,
// together with their trustlog statuses and paging metadata.
//
// Paging defaults:
//   - PageSize <= 0 becomes 20, and values above 1000 are capped at 1000.
//   - PageNumber <= 0 becomes 1.
//
// OrderBy is restricted to created_at, timestamp or op_id; anything else,
// including "", falls back to created_at. Total comes from Count with the same
// filters.
//
// It returns an error if req is nil or if any database call fails.
func (r *operationRepository) Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) {
	if req == nil {
		return nil, fmt.Errorf("query request cannot be nil")
	}

	pageSize := req.PageSize
	if pageSize <= 0 {
		pageSize = 20
	}
	if pageSize > 1000 {
		pageSize = 1000
	}
	pageNumber := req.PageNumber
	if pageNumber <= 0 {
		pageNumber = 1
	}

	// The ORDER BY column is formatted into the SQL text (it cannot be a bind
	// parameter), so only whitelisted column names are accepted.
	orderBy := req.OrderBy
	switch orderBy {
	case "created_at", "timestamp", "op_id":
	default:
		orderBy = "created_at"
	}
	orderDirection := "ASC"
	if req.OrderDesc {
		orderDirection = "DESC"
	}

	total, err := r.Count(ctx, req)
	if err != nil {
		return nil, err
	}

	whereClause, args := buildOperationWhereClause(req)
	offset := (pageNumber - 1) * pageSize
	// Placeholders are written as "?" and adapted by convertPlaceholders, so
	// the same statement works for PostgreSQL ($n) as well as SQLite/MySQL (?).
	query := r.convertPlaceholders(fmt.Sprintf(`
SELECT
op_id, op_actor, doid, producer_id,
request_body_hash, response_body_hash,
op_source, op_code, do_prefix, do_repository,
client_ip, server_ip, trustlog_status, timestamp, created_at
FROM operation
%s
ORDER BY %s %s
LIMIT ? OFFSET ?
`, whereClause, orderBy, orderDirection))
	args = append(args, pageSize, offset)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to query operations",
			"error", err,
		)
		return nil, fmt.Errorf("failed to query operations: %w", err)
	}
	defer rows.Close()

	var operations []*model.Operation
	var statuses []TrustlogStatus
	for rows.Next() {
		var op model.Operation
		var reqHash, respHash, clientIP, serverIP, statusStr sql.NullString
		var createdAt time.Time // scanned but not part of the result
		if err := rows.Scan(
			&op.OpID, &op.OpActor, &op.Doid, &op.ProducerID,
			&reqHash, &respHash,
			&op.OpSource, &op.OpCode, &op.DoPrefix, &op.DoRepository,
			&clientIP, &serverIP, &statusStr, &op.Timestamp, &createdAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan operation row: %w", err)
		}

		// Map NULL columns to nil pointers.
		if reqHash.Valid {
			op.RequestBodyHash = &reqHash.String
		}
		if respHash.Valid {
			op.ResponseBodyHash = &respHash.String
		}
		if clientIP.Valid {
			op.ClientIP = &clientIP.String
		}
		if serverIP.Valid {
			op.ServerIP = &serverIP.String
		}
		operations = append(operations, &op)
		statuses = append(statuses, TrustlogStatus(statusStr.String))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating operation rows: %w", err)
	}

	// Ceiling division in int64 so a large total cannot overflow int.
	size := int64(pageSize)
	totalPages := int((total + size - 1) / size)

	return &OperationQueryResult{
		Operations: operations,
		Statuses:   statuses,
		Total:      total,
		PageSize:   pageSize,
		PageNumber: pageNumber,
		TotalPages: totalPages,
	}, nil
}

// Count returns the number of operation records matching the filters in req.
// Paging and ordering fields of req are ignored.
//
// It returns an error if req is nil or if the COUNT query fails.
func (r *operationRepository) Count(ctx context.Context, req *OperationQueryRequest) (int64, error) {
	if req == nil {
		return 0, fmt.Errorf("query request cannot be nil")
	}

	whereClause, args := buildOperationWhereClause(req)
	query := r.convertPlaceholders(fmt.Sprintf(`
SELECT COUNT(*) FROM operation %s
`, whereClause))

	var count int64
	if err := r.db.QueryRowContext(ctx, query, args...).Scan(&count); err != nil {
		r.logger.ErrorContext(ctx, "failed to count operations",
			"error", err,
		)
		return 0, fmt.Errorf("failed to count operations: %w", err)
	}
	return count, nil
}

// buildOperationWhereClause turns the optional filters in req into a SQL WHERE
// clause and its bind arguments. Query and Count share it so both always apply
// identical filters.
//
// Filter rules:
//   - Nil filters are skipped, and so are empty strings for the string filters.
//   - Doid and DoPrefix match with LIKE '%value%'; a % or _ inside the value
//     also acts as a wildcard.
//   - TimeFrom and TimeTo bound the timestamp column inclusively.
//
// Placeholders are "?", so the result must go through convertPlaceholders
// before execution. The clause is "" when no filter is set.
func buildOperationWhereClause(req *OperationQueryRequest) (string, []interface{}) {
	var conditions []string
	var args []interface{}
	add := func(condition string, arg interface{}) {
		conditions = append(conditions, condition)
		args = append(args, arg)
	}
	nonEmpty := func(s *string) bool { return s != nil && *s != "" }

	if nonEmpty(req.OpID) {
		add("op_id = ?", *req.OpID)
	}
	if nonEmpty(req.OpSource) {
		add("op_source = ?", *req.OpSource)
	}
	if req.OpCode != nil {
		add("op_code = ?", *req.OpCode)
	}
	if nonEmpty(req.Doid) {
		add("doid LIKE ?", "%"+*req.Doid+"%")
	}
	if nonEmpty(req.ProducerID) {
		add("producer_id = ?", *req.ProducerID)
	}
	if nonEmpty(req.OpActor) {
		add("op_actor = ?", *req.OpActor)
	}
	if nonEmpty(req.DoPrefix) {
		add("do_prefix LIKE ?", "%"+*req.DoPrefix+"%")
	}
	if nonEmpty(req.DoRepository) {
		add("do_repository = ?", *req.DoRepository)
	}
	if req.TrustlogStatus != nil {
		add("trustlog_status = ?", string(*req.TrustlogStatus))
	}
	if nonEmpty(req.ClientIP) {
		add("client_ip = ?", *req.ClientIP)
	}
	if nonEmpty(req.ServerIP) {
		add("server_ip = ?", *req.ServerIP)
	}
	if req.TimeFrom != nil {
		add("timestamp >= ?", *req.TimeFrom)
	}
	if req.TimeTo != nil {
		add("timestamp <= ?", *req.TimeTo)
	}

	whereClause := ""
	for i, condition := range conditions {
		if i == 0 {
			whereClause = "WHERE " + condition
		} else {
			whereClause += " AND " + condition
		}
	}
	return whereClause, args
}
// cursorRepository is the database/sql implementation of CursorRepository,
// backed by the trustlog_cursor table.
type cursorRepository struct {
	// db is used for statements that do not run in a caller-supplied transaction.
	db *sql.DB
	// logger receives the error/debug logs emitted by the methods.
	logger logger.Logger
	// driverName is the dialect reported by detectDriverName; it decides how
	// "?" placeholders are rewritten.
	driverName string
}
// NewCursorRepository returns a CursorRepository backed by db. The SQL dialect
// is detected once here and reused for every statement.
func NewCursorRepository(db *sql.DB, log logger.Logger) CursorRepository {
	repo := &cursorRepository{db: db, logger: log}
	repo.driverName = detectDriverName(db)
	return repo
}
// convertPlaceholders adapts the "?" placeholders in query to the dialect
// detected when the repository was constructed.
func (r *cursorRepository) convertPlaceholders(query string) string {
	driver := r.driverName
	return convertPlaceholdersForDriver(query, driver)
}
// GetCursor returns the value stored under cursorKey.
//
// A missing key is not an error: it is logged at debug level and "" is
// returned. Other database errors are logged and returned wrapped.
func (r *cursorRepository) GetCursor(ctx context.Context, cursorKey string) (string, error) {
	stmt := r.convertPlaceholders(`SELECT cursor_value FROM trustlog_cursor WHERE cursor_key = ?`)

	var value string
	err := r.db.QueryRowContext(ctx, stmt, cursorKey).Scan(&value)
	switch {
	case err == nil:
		return value, nil
	case err == sql.ErrNoRows:
		r.logger.DebugContext(ctx, "cursor not found",
			"cursorKey", cursorKey,
		)
		return "", nil
	default:
		r.logger.ErrorContext(ctx, "failed to get cursor",
			"cursorKey", cursorKey,
			"error", err,
		)
		return "", fmt.Errorf("failed to get cursor: %w", err)
	}
}
// UpdateCursor 更新游标值
func (r *cursorRepository) UpdateCursor(ctx context.Context, cursorKey string, cursorValue string) error {
return r.UpdateCursorTx(ctx, nil, cursorKey, cursorValue)
}
// UpdateCursorTx upserts cursorValue under cursorKey and stamps
// last_updated_at with the current time.
//
// It uses INSERT ... ON CONFLICT (cursor_key) DO UPDATE, which PostgreSQL and
// SQLite accept. The statement runs in tx when non-nil, otherwise on the
// repository's *sql.DB. Failures are logged and returned wrapped.
func (r *cursorRepository) UpdateCursorTx(ctx context.Context, tx *sql.Tx, cursorKey string, cursorValue string) error {
	stmt := r.convertPlaceholders(`
INSERT INTO trustlog_cursor (cursor_key, cursor_value, last_updated_at)
VALUES (?, ?, ?)
ON CONFLICT (cursor_key) DO UPDATE SET
cursor_value = excluded.cursor_value,
last_updated_at = excluded.last_updated_at
`)

	// execer is satisfied by both *sql.DB and *sql.Tx.
	type execer interface {
		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	}
	var target execer = r.db
	if tx != nil {
		target = tx
	}

	stamp := time.Now()
	if _, execErr := target.ExecContext(ctx, stmt, cursorKey, cursorValue, stamp); execErr != nil {
		r.logger.ErrorContext(ctx, "failed to update cursor",
			"cursorKey", cursorKey,
			"error", execErr,
		)
		return fmt.Errorf("failed to update cursor: %w", execErr)
	}

	r.logger.DebugContext(ctx, "cursor updated",
		"cursorKey", cursorKey,
		"cursorValue", cursorValue,
	)
	return nil
}
// InitCursor writes initialValue under cursorKey.
//
// Despite the "if it does not exist" wording in CursorRepository, this is a
// plain upsert: an existing value is overwritten. That is intentional here, so
// the cursor is always initialized from the latest database state. Failures
// are logged and returned wrapped.
func (r *cursorRepository) InitCursor(ctx context.Context, cursorKey string, initialValue string) error {
	stmt := r.convertPlaceholders(`
INSERT INTO trustlog_cursor (cursor_key, cursor_value, last_updated_at)
VALUES (?, ?, ?)
ON CONFLICT (cursor_key)
DO UPDATE SET
cursor_value = EXCLUDED.cursor_value,
last_updated_at = EXCLUDED.last_updated_at
`)

	stamp := time.Now()
	if _, execErr := r.db.ExecContext(ctx, stmt, cursorKey, initialValue, stamp); execErr != nil {
		r.logger.ErrorContext(ctx, "failed to init cursor",
			"cursorKey", cursorKey,
			"error", execErr,
		)
		return fmt.Errorf("failed to init cursor: %w", execErr)
	}

	r.logger.DebugContext(ctx, "cursor initialized",
		"cursorKey", cursorKey,
		"initialValue", initialValue,
	)
	return nil
}
// retryRepository is the database/sql implementation of RetryRepository,
// backed by the trustlog_retry table.
type retryRepository struct {
	// db executes every statement of this repository except AddRetryTx with a
	// non-nil transaction.
	db *sql.DB
	// logger receives the error/warn/debug logs emitted by the methods.
	logger logger.Logger
	// driverName is the dialect reported by detectDriverName; it decides how
	// "?" placeholders are rewritten.
	driverName string
}
// NewRetryRepository returns a RetryRepository backed by db, with the SQL
// dialect detected once at construction.
func NewRetryRepository(db *sql.DB, log logger.Logger) RetryRepository {
	return &retryRepository{
		db:         db,
		logger:     log,
		driverName: detectDriverName(db),
	}
}
// convertPlaceholders adapts the "?" placeholders in query to the dialect
// detected when the repository was constructed.
func (r *retryRepository) convertPlaceholders(query string) string {
	driver := r.driverName
	return convertPlaceholdersForDriver(query, driver)
}
func (r *retryRepository) AddRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error {
return r.AddRetryTx(ctx, nil, opID, errorMsg, nextRetryAt)
}
// AddRetryTx inserts a new trustlog_retry row for opID. The row has
// retry_count 0, status RetryStatusPending, errorMsg, the given next_retry_at,
// and updated_at set to the current time.
//
// The INSERT runs in tx when non-nil, otherwise on the repository's *sql.DB.
// Failures are logged and returned wrapped.
func (r *retryRepository) AddRetryTx(ctx context.Context, tx *sql.Tx, opID string, errorMsg string, nextRetryAt time.Time) error {
	stmt := r.convertPlaceholders(`
INSERT INTO trustlog_retry (op_id, retry_count, retry_status, error_message, next_retry_at, updated_at)
VALUES (?, 0, ?, ?, ?, ?)
`)

	// execer is satisfied by both *sql.DB and *sql.Tx.
	type execer interface {
		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	}
	var target execer = r.db
	if tx != nil {
		target = tx
	}

	now := time.Now()
	if _, execErr := target.ExecContext(ctx, stmt, opID, string(RetryStatusPending), errorMsg, nextRetryAt, now); execErr != nil {
		r.logger.ErrorContext(ctx, "failed to add retry record",
			"opID", opID,
			"error", execErr,
		)
		return fmt.Errorf("failed to add retry record: %w", execErr)
	}

	r.logger.DebugContext(ctx, "retry record added",
		"opID", opID,
		"nextRetryAt", nextRetryAt,
	)
	return nil
}
// IncrementRetry records one more failed attempt for opID. It:
//   - increments retry_count,
//   - sets the status to RetryStatusRetrying,
//   - stores errorMsg,
//   - reschedules the next attempt at nextRetryAt.
//
// last_retry_at and updated_at come from a single clock reading, so the two
// columns agree exactly. The affected-row count is not inspected, so an opID
// without a retry record is not reported as an error. Failures are logged and
// returned wrapped.
func (r *retryRepository) IncrementRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error {
	query := r.convertPlaceholders(`
UPDATE trustlog_retry
SET retry_count = retry_count + 1,
retry_status = ?,
last_retry_at = ?,
next_retry_at = ?,
error_message = ?,
updated_at = ?
WHERE op_id = ?
`)

	// One timestamp for both columns; two time.Now() calls would disagree.
	now := time.Now()
	_, err := r.db.ExecContext(ctx, query,
		string(RetryStatusRetrying),
		now,
		nextRetryAt,
		errorMsg,
		now,
		opID,
	)
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to increment retry",
			"opID", opID,
			"error", err,
		)
		return fmt.Errorf("failed to increment retry: %w", err)
	}

	r.logger.DebugContext(ctx, "retry incremented",
		"opID", opID,
		"nextRetryAt", nextRetryAt,
	)
	return nil
}
// MarkAsDeadLetter moves the retry record of opID to RetryStatusDeadLetter,
// storing errorMsg and touching updated_at. On success it logs a warning. The
// affected-row count is not inspected, so a missing record is not reported.
// Database failures are logged and returned wrapped.
func (r *retryRepository) MarkAsDeadLetter(ctx context.Context, opID string, errorMsg string) error {
	stmt := r.convertPlaceholders(`
UPDATE trustlog_retry
SET retry_status = ?,
error_message = ?,
updated_at = ?
WHERE op_id = ?
`)

	deadLetter := string(RetryStatusDeadLetter)
	if _, execErr := r.db.ExecContext(ctx, stmt, deadLetter, errorMsg, time.Now(), opID); execErr != nil {
		r.logger.ErrorContext(ctx, "failed to mark as dead letter",
			"opID", opID,
			"error", execErr,
		)
		return fmt.Errorf("failed to mark as dead letter: %w", execErr)
	}

	r.logger.WarnContext(ctx, "operation marked as dead letter",
		"opID", opID,
		"error", errorMsg,
	)
	return nil
}
// FindPendingRetries returns up to limit retry records that are due. A record
// is due when its status is RetryStatusPending or RetryStatusRetrying and
// next_retry_at is not after the current time. Records are ordered by
// next_retry_at ascending.
//
// NULL last_retry_at / next_retry_at become nil pointers. A scan failure
// aborts the call with an error.
func (r *retryRepository) FindPendingRetries(ctx context.Context, limit int) ([]RetryRecord, error) {
	stmt := r.convertPlaceholders(`
SELECT
op_id, retry_count, retry_status,
last_retry_at, next_retry_at, error_message,
created_at, updated_at
FROM trustlog_retry
WHERE retry_status IN (?, ?) AND next_retry_at <= ?
ORDER BY next_retry_at ASC
LIMIT ?
`)

	rows, queryErr := r.db.QueryContext(ctx, stmt,
		string(RetryStatusPending),
		string(RetryStatusRetrying),
		time.Now(),
		limit,
	)
	if queryErr != nil {
		r.logger.ErrorContext(ctx, "failed to find pending retries",
			"error", queryErr,
		)
		return nil, fmt.Errorf("failed to find pending retries: %w", queryErr)
	}
	defer rows.Close()

	var due []RetryRecord
	for rows.Next() {
		rec := RetryRecord{}
		var lastAt, nextAt sql.NullTime
		if scanErr := rows.Scan(
			&rec.OpID,
			&rec.RetryCount,
			&rec.RetryStatus,
			&lastAt,
			&nextAt,
			&rec.ErrorMessage,
			&rec.CreatedAt,
			&rec.UpdatedAt,
		); scanErr != nil {
			r.logger.ErrorContext(ctx, "failed to scan retry record",
				"error", scanErr,
			)
			return nil, fmt.Errorf("failed to scan retry record: %w", scanErr)
		}

		if lastAt.Valid {
			rec.LastRetryAt = &lastAt.Time
		}
		if nextAt.Valid {
			rec.NextRetryAt = &nextAt.Time
		}
		due = append(due, rec)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating retry records: %w", iterErr)
	}
	return due, nil
}
// DeleteRetry removes the trustlog_retry row for opID, typically once the
// operation has been trustlogged successfully. Deleting a missing row is not
// an error. Database failures are logged and returned wrapped.
func (r *retryRepository) DeleteRetry(ctx context.Context, opID string) error {
	stmt := r.convertPlaceholders(`DELETE FROM trustlog_retry WHERE op_id = ?`)

	if _, execErr := r.db.ExecContext(ctx, stmt, opID); execErr != nil {
		r.logger.ErrorContext(ctx, "failed to delete retry record",
			"opID", opID,
			"error", execErr,
		)
		return fmt.Errorf("failed to delete retry record: %w", execErr)
	}

	r.logger.DebugContext(ctx, "retry record deleted",
		"opID", opID,
	)
	return nil
}