package persistence

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"go.yandata.net/iod/iod/go-trustlog/api/logger"
	"go.yandata.net/iod/iod/go-trustlog/api/model"
)

// OperationQueryRequest describes the filters, paging and ordering of an
// operation-record query. A nil pointer field is not used as a filter; for the
// string filters an empty string is ignored as well.
type OperationQueryRequest struct {
	// OpID filters by operation ID (exact match).
	OpID *string
	// OpSource filters by operation source (exact match).
	OpSource *string
	// OpType filters by operation type (exact match).
	OpType *string
	// Doid filters by digital object identifier (substring match via LIKE).
	Doid *string
	// ProducerID filters by producer ID (exact match).
	ProducerID *string
	// OpActor filters by the actor that performed the operation (exact match).
	OpActor *string
	// DoPrefix filters by DO prefix (substring match via LIKE).
	DoPrefix *string
	// DoRepository filters by DO repository (exact match).
	DoRepository *string
	// TrustlogStatus filters by trustlog status (exact match).
	TrustlogStatus *TrustlogStatus
	// ClientIP filters by client IP (exact match).
	ClientIP *string
	// ServerIP filters by server IP (exact match).
	ServerIP *string
	// TimeFrom is the inclusive lower bound on the operation timestamp.
	TimeFrom *time.Time
	// TimeTo is the inclusive upper bound on the operation timestamp.
	TimeTo *time.Time

	// PageSize is the number of records per page (default 20, capped at 1000).
	PageSize int
	// PageNumber is the 1-based page number; values <= 0 are treated as 1.
	PageNumber int

	// OrderBy is the sort column: created_at, timestamp or op_id. Any other
	// value falls back to created_at.
	OrderBy string
	// OrderDesc selects descending order; the default (false) is ascending.
	OrderDesc bool
}

// OperationQueryResult is one page of an operation-record query.
type OperationQueryResult struct {
	// Operations holds the records of the requested page.
	Operations []*model.Operation
	// Statuses holds the trustlog status of each record, index-aligned with
	// Operations.
	Statuses []TrustlogStatus
	// Total is the number of records matching the filters.
	Total int64
	// PageSize is the effective page size used for the query.
	PageSize int
	// PageNumber is the effective page number used for the query.
	PageNumber int
	// TotalPages is Total divided by PageSize, rounded up.
	TotalPages int
}

// OperationRepository is the database repository for operation records.
type OperationRepository interface {
	// Save stores an operation record together with its trustlog status.
	Save(ctx context.Context, op *model.Operation, status TrustlogStatus) error
	// SaveTx stores an operation record inside the given transaction.
	SaveTx(ctx context.Context, tx *sql.Tx, op *model.Operation, status TrustlogStatus) error
	// UpdateStatus updates the trustlog status of an operation record.
	UpdateStatus(ctx context.Context, opID string, status TrustlogStatus) error
	// UpdateStatusTx updates the trustlog status inside the given transaction.
	UpdateStatusTx(ctx context.Context, tx *sql.Tx, opID string, status TrustlogStatus) error
	// FindByID looks an operation record up by its OpID.
	FindByID(ctx context.Context, opID string) (*model.Operation, TrustlogStatus, error)
	// FindUntrustlogged returns operation records that have not been
	// trustlogged yet (used by the retry mechanism).
	FindUntrustlogged(ctx context.Context, limit int) ([]*model.Operation, error)
	// FindUntrustloggedWithLock returns not-yet-trustlogged operations using
	// SELECT ... FOR UPDATE SKIP LOCKED so that several workers do not pick up
	// the same records.
	// Returns: operations, opIDs, error.
	FindUntrustloggedWithLock(ctx context.Context, tx *sql.Tx, limit int) ([]*model.Operation, []string, error)
	// UpdateStatusWithCAS updates the status with compare-and-set semantics:
	// the row is only updated when its current status equals expectedStatus.
	// Returns: updated (whether a row was changed), error.
	UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error)
	// Query returns operation records matching req (paging, filtering, sorting).
	Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error)
	// Count returns the number of records matching the filters in req.
	Count(ctx context.Context, req *OperationQueryRequest) (int64, error)
}

// CursorRepository is the cursor repository (key-value style).
type CursorRepository interface {
	// GetCursor returns the value stored for cursorKey.
	GetCursor(ctx context.Context, cursorKey string) (string, error)
	// UpdateCursor stores cursorValue for cursorKey.
	UpdateCursor(ctx context.Context, cursorKey string, cursorValue string) error
	// UpdateCursorTx stores cursorValue for cursorKey inside the given transaction.
	UpdateCursorTx(ctx context.Context, tx *sql.Tx, cursorKey string, cursorValue string) error
	// InitCursor initializes the cursor identified by cursorKey.
	// NOTE(review): the original comment says "if it does not exist", but the
	// implementation in this file upserts and therefore overwrites an existing
	// value — confirm which contract is intended.
	InitCursor(ctx context.Context, cursorKey string, initialValue string) error
}

// RetryRepository is the repository for retry bookkeeping.
type RetryRepository interface {
	// AddRetry adds a retry record.
	AddRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error
	// AddRetryTx adds a retry record inside the given transaction.
	AddRetryTx(ctx context.Context, tx *sql.Tx, opID string, errorMsg string, nextRetryAt time.Time) error
	// IncrementRetry increments the retry count of a record.
	IncrementRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error
	// MarkAsDeadLetter marks a record as dead letter.
	MarkAsDeadLetter(ctx context.Context, opID string, errorMsg string) error
	// FindPendingRetries returns records that are waiting to be retried.
	FindPendingRetries(ctx context.Context, limit int) ([]RetryRecord, error)
	// DeleteRetry removes a retry record (cleanup after success).
	DeleteRetry(ctx context.Context, opID string) error
}

// RetryRecord is one row of the retry table.
type RetryRecord struct {
	OpID         string
	RetryCount   int
	RetryStatus  RetryStatus
	LastRetryAt  *time.Time
	NextRetryAt  *time.Time
	ErrorMessage string
	CreatedAt    time.Time
	UpdatedAt    time.Time
}

// operationRepository is the database/sql implementation of OperationRepository.
type operationRepository struct {
	db         *sql.DB
	logger     logger.Logger
	driverName string
}

// detectDriverName probes the database and returns "postgres" when
// SELECT version() succeeds and reports a string starting with "PostgreSQL".
// In every other case, including a nil db, it returns "sqlite3".
func detectDriverName(db *sql.DB) string {
	if db == nil {
		return "sqlite3"
	}

	// version() is used as a PostgreSQL fingerprint.
	var version string
	if err := db.QueryRow("SELECT version()").Scan(&version); err == nil && strings.HasPrefix(version, "PostgreSQL") {
		return "postgres"
	}
	return "sqlite3" // default
}

// convertPlaceholdersForDriver rewrites "?" placeholders into the form the
// given driver expects. For "postgres" every "?" becomes $1, $2, ... in order
// of appearance; for any other driver name the query is returned unchanged.
// All other bytes are copied verbatim, so multi-byte UTF-8 text in the query
// survives the rewrite.
func convertPlaceholdersForDriver(query, driverName string) string {
	if driverName != "postgres" {
		// Other databases (SQLite, MySQL) use "?" directly.
		return query
	}

	var b strings.Builder
	b.Grow(len(query) + 16)
	n := 0
	for i := 0; i < len(query); i++ {
		if query[i] != '?' {
			b.WriteByte(query[i])
			continue
		}
		n++
		fmt.Fprintf(&b, "$%d", n)
	}
	return b.String()
}

// nullStringFromPtr converts an optional string into a sql.NullString; a nil
// pointer becomes SQL NULL.
func nullStringFromPtr(s *string) sql.NullString {
	if s == nil {
		return sql.NullString{}
	}
	return sql.NullString{String: *s, Valid: true}
}

// stringPtrFromNull converts a scanned sql.NullString back into an optional
// string; SQL NULL becomes a nil pointer.
func stringPtrFromNull(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}
	return &ns.String
}

// buildOperationFilter turns the filter fields of req into a WHERE clause that
// uses "?" placeholders, plus the matching argument list. It returns an empty
// clause and nil arguments when no filter is set. req must not be nil.
func buildOperationFilter(req *OperationQueryRequest) (string, []interface{}) {
	var conditions []string
	var args []interface{}

	add := func(cond string, arg interface{}) {
		conditions = append(conditions, cond)
		args = append(args, arg)
	}
	// String filters are skipped when nil or empty.
	addExact := func(cond string, v *string) {
		if v != nil && *v != "" {
			add(cond, *v)
		}
	}
	addLike := func(cond string, v *string) {
		if v != nil && *v != "" {
			add(cond, "%"+*v+"%")
		}
	}

	// The order of these calls fixes the order of the placeholders.
	addExact("op_id = ?", req.OpID)
	addExact("op_source = ?", req.OpSource)
	addExact("op_type = ?", req.OpType)
	addLike("doid LIKE ?", req.Doid)
	addExact("producer_id = ?", req.ProducerID)
	addExact("op_actor = ?", req.OpActor)
	addLike("do_prefix LIKE ?", req.DoPrefix)
	addExact("do_repository = ?", req.DoRepository)
	if req.TrustlogStatus != nil {
		add("trustlog_status = ?", string(*req.TrustlogStatus))
	}
	addExact("client_ip = ?", req.ClientIP)
	addExact("server_ip = ?", req.ServerIP)
	if req.TimeFrom != nil {
		add("timestamp >= ?", *req.TimeFrom)
	}
	if req.TimeTo != nil {
		add("timestamp <= ?", *req.TimeTo)
	}

	if len(conditions) == 0 {
		return "", args
	}
	return "WHERE " + strings.Join(conditions, " AND "), args
}

// NewOperationRepository creates an operation repository on top of db. The
// driver flavor is detected once at construction time.
func NewOperationRepository(db *sql.DB, log logger.Logger) OperationRepository {
	return &operationRepository{
		db:         db,
		logger:     log,
		driverName: detectDriverName(db),
	}
}

// convertPlaceholders rewrites "?" placeholders for this repository's driver.
func (r *operationRepository) convertPlaceholders(query string) string {
	return convertPlaceholdersForDriver(query, r.driverName)
}

// Save stores op with the given status outside of any transaction.
func (r *operationRepository) Save(ctx context.Context, op *model.Operation, status TrustlogStatus) error {
	return r.SaveTx(ctx, nil, op, status)
}

// SaveTx inserts op with the given status. When tx is nil the statement runs
// directly on the repository's *sql.DB. A nil op is rejected with an error.
func (r *operationRepository) SaveTx(ctx context.Context, tx *sql.Tx, op *model.Operation, status TrustlogStatus) error {
	if op == nil {
		return fmt.Errorf("operation cannot be nil")
	}

	query := r.convertPlaceholders(`
		INSERT INTO operation (
			op_id, op_actor, doid, producer_id,
			request_body_hash, response_body_hash,
			op_source, op_type, do_prefix, do_repository,
			client_ip, server_ip, trustlog_status, timestamp
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`)

	args := []interface{}{
		op.OpID,
		op.OpActor,
		op.Doid,
		op.ProducerID,
		nullStringFromPtr(op.RequestBodyHash),
		nullStringFromPtr(op.ResponseBodyHash),
		string(op.OpSource),
		string(op.OpType),
		op.DoPrefix,
		op.DoRepository,
		nullStringFromPtr(op.ClientIP),
		nullStringFromPtr(op.ServerIP),
		string(status),
		op.Timestamp,
	}

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, args...)
	} else {
		_, err = r.db.ExecContext(ctx, query, args...)
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to save operation",
			"opID", op.OpID,
			"error", err,
		)
		return fmt.Errorf("failed to save operation: %w", err)
	}

	r.logger.DebugContext(ctx, "operation saved to database",
		"opID", op.OpID,
		"status", status,
	)
	return nil
}

// UpdateStatus sets the trustlog status of opID outside of any transaction.
func (r *operationRepository) UpdateStatus(ctx context.Context, opID string, status TrustlogStatus) error {
	return r.UpdateStatusTx(ctx, nil, opID, status)
}

// UpdateStatusTx sets the trustlog status of opID. When tx is nil the
// statement runs directly on the repository's *sql.DB. The number of affected
// rows is not checked, so an unknown opID is not reported as an error.
func (r *operationRepository) UpdateStatusTx(ctx context.Context, tx *sql.Tx, opID string, status TrustlogStatus) error {
	query := r.convertPlaceholders(`UPDATE operation SET trustlog_status = ? WHERE op_id = ?`)

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, string(status), opID)
	} else {
		_, err = r.db.ExecContext(ctx, query, string(status), opID)
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to update operation status",
			"opID", opID,
			"status", status,
			"error", err,
		)
		return fmt.Errorf("failed to update operation status: %w", err)
	}

	r.logger.DebugContext(ctx, "operation status updated",
		"opID", opID,
		"status", status,
	)
	return nil
}

// FindByID loads the operation with the given opID together with its trustlog
// status. A missing row yields an "operation not found" error.
func (r *operationRepository) FindByID(ctx context.Context, opID string) (*model.Operation, TrustlogStatus, error) {
	query := r.convertPlaceholders(`
		SELECT
			op_id, op_actor, doid, producer_id,
			request_body_hash, response_body_hash,
			op_source, op_type, do_prefix, do_repository,
			client_ip, server_ip, trustlog_status, timestamp
		FROM operation
		WHERE op_id = ?
	`)

	var op model.Operation
	var statusStr string
	var reqHash, respHash, clientIP, serverIP sql.NullString

	err := r.db.QueryRowContext(ctx, query, opID).Scan(
		&op.OpID,
		&op.OpActor,
		&op.Doid,
		&op.ProducerID,
		&reqHash,
		&respHash,
		&op.OpSource,
		&op.OpType,
		&op.DoPrefix,
		&op.DoRepository,
		&clientIP,
		&serverIP,
		&statusStr,
		&op.Timestamp,
	)
	// errors.Is also matches a wrapped ErrNoRows.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, "", fmt.Errorf("operation not found: %s", opID)
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to find operation",
			"opID", opID,
			"error", err,
		)
		return nil, "", fmt.Errorf("failed to find operation: %w", err)
	}

	op.RequestBodyHash = stringPtrFromNull(reqHash)
	op.ResponseBodyHash = stringPtrFromNull(respHash)
	op.ClientIP = stringPtrFromNull(clientIP)
	op.ServerIP = stringPtrFromNull(serverIP)

	return &op, TrustlogStatus(statusStr), nil
}

// FindUntrustloggedWithLock returns up to limit operations whose status is
// StatusNotTrustlogged, oldest timestamp first, together with their op IDs.
// The query uses FOR UPDATE SKIP LOCKED so rows already locked by another
// transaction are skipped instead of waited for. When tx is nil the query runs
// directly on the repository's *sql.DB.
//
// Rows that fail to scan are logged and skipped rather than aborting the call.
// NOTE(review): the lock clause is sent regardless of the detected driver —
// confirm that every supported database accepts it.
func (r *operationRepository) FindUntrustloggedWithLock(ctx context.Context, tx *sql.Tx, limit int) ([]*model.Operation, []string, error) {
	query := r.convertPlaceholders(`
		SELECT
			op_id, op_actor, doid, producer_id,
			request_body_hash, response_body_hash,
			op_source, op_type, do_prefix, do_repository,
			client_ip, server_ip, timestamp
		FROM operation
		WHERE trustlog_status = ?
		ORDER BY timestamp ASC
		LIMIT ?
		FOR UPDATE SKIP LOCKED
	`)

	var rows *sql.Rows
	var err error
	if tx != nil {
		rows, err = tx.QueryContext(ctx, query, string(StatusNotTrustlogged), limit)
	} else {
		rows, err = r.db.QueryContext(ctx, query, string(StatusNotTrustlogged), limit)
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to find untrustlogged operations with lock",
			"error", err,
		)
		return nil, nil, fmt.Errorf("failed to find untrustlogged operations: %w", err)
	}
	defer rows.Close()

	var operations []*model.Operation
	var opIDs []string

	for rows.Next() {
		var op model.Operation
		var reqHash, respHash, clientIP, serverIP sql.NullString

		err := rows.Scan(
			&op.OpID,
			&op.OpActor,
			&op.Doid,
			&op.ProducerID,
			&reqHash,
			&respHash,
			&op.OpSource,
			&op.OpType,
			&op.DoPrefix,
			&op.DoRepository,
			&clientIP,
			&serverIP,
			&op.Timestamp,
		)
		if err != nil {
			// Best effort: a bad row is reported and skipped.
			r.logger.ErrorContext(ctx, "failed to scan operation",
				"error", err,
			)
			continue
		}

		op.RequestBodyHash = stringPtrFromNull(reqHash)
		op.ResponseBodyHash = stringPtrFromNull(respHash)
		op.ClientIP = stringPtrFromNull(clientIP)
		op.ServerIP = stringPtrFromNull(serverIP)

		operations = append(operations, &op)
		opIDs = append(opIDs, op.OpID)
	}

	if err := rows.Err(); err != nil {
		r.logger.ErrorContext(ctx, "error iterating rows",
			"error", err,
		)
		return nil, nil, fmt.Errorf("error iterating rows: %w", err)
	}

	return operations, opIDs, nil
}

// UpdateStatusWithCAS sets the status of opID to newStatus only if the row
// currently has expectedStatus. It reports true when a row was updated and
// false (with a nil error) when no row matched. When tx is nil the statement
// runs directly on the repository's *sql.DB.
func (r *operationRepository) UpdateStatusWithCAS(ctx context.Context, tx *sql.Tx, opID string, expectedStatus, newStatus TrustlogStatus) (bool, error) {
	query := r.convertPlaceholders(`
		UPDATE operation
		SET trustlog_status = ?
		WHERE op_id = ? AND trustlog_status = ?
	`)

	var result sql.Result
	var err error
	if tx != nil {
		result, err = tx.ExecContext(ctx, query, string(newStatus), opID, string(expectedStatus))
	} else {
		result, err = r.db.ExecContext(ctx, query, string(newStatus), opID, string(expectedStatus))
	}
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to update operation status with CAS",
			"opID", opID,
			"expectedStatus", expectedStatus,
			"newStatus", newStatus,
			"error", err,
		)
		return false, fmt.Errorf("failed to update operation status: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return false, fmt.Errorf("failed to get rows affected: %w", err)
	}

	// Zero affected rows means the WHERE clause (op_id + expected status) did
	// not match any row.
	if rowsAffected == 0 {
		r.logger.WarnContext(ctx, "CAS update failed: status already changed by another worker",
			"opID", opID,
			"expectedStatus", expectedStatus,
		)
		return false, nil
	}

	r.logger.DebugContext(ctx, "operation status updated with CAS",
		"opID", opID,
		"expectedStatus", expectedStatus,
		"newStatus", newStatus,
	)
	return true, nil
}

// FindUntrustlogged returns up to limit operations whose status is
// StatusNotTrustlogged, oldest timestamp first. Unlike the locking variant, a
// row that fails to scan aborts the call with an error.
func (r *operationRepository) FindUntrustlogged(ctx context.Context, limit int) ([]*model.Operation, error) {
	query := r.convertPlaceholders(`
		SELECT
			op_id, op_actor, doid, producer_id,
			request_body_hash, response_body_hash,
			op_source, op_type, do_prefix, do_repository,
			client_ip, server_ip, timestamp
		FROM operation
		WHERE trustlog_status = ?
		ORDER BY timestamp ASC
		LIMIT ?
	`)

	rows, err := r.db.QueryContext(ctx, query, string(StatusNotTrustlogged), limit)
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to find untrustlogged operations",
			"error", err,
		)
		return nil, fmt.Errorf("failed to find untrustlogged operations: %w", err)
	}
	defer rows.Close()

	var operations []*model.Operation
	for rows.Next() {
		var op model.Operation
		var reqHash, respHash, clientIP, serverIP sql.NullString

		err := rows.Scan(
			&op.OpID,
			&op.OpActor,
			&op.Doid,
			&op.ProducerID,
			&reqHash,
			&respHash,
			&op.OpSource,
			&op.OpType,
			&op.DoPrefix,
			&op.DoRepository,
			&clientIP,
			&serverIP,
			&op.Timestamp,
		)
		if err != nil {
			r.logger.ErrorContext(ctx, "failed to scan operation row",
				"error", err,
			)
			return nil, fmt.Errorf("failed to scan operation row: %w", err)
		}

		op.RequestBodyHash = stringPtrFromNull(reqHash)
		op.ResponseBodyHash = stringPtrFromNull(respHash)
		op.ClientIP = stringPtrFromNull(clientIP)
		op.ServerIP = stringPtrFromNull(serverIP)

		operations = append(operations, &op)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating operation rows: %w", err)
	}

	return operations, nil
}

// Query returns one page of operation records matching req, along with the
// total match count and paging metadata. It returns an error when req is nil.
//
// Defaults: PageSize <= 0 becomes 20 and is capped at 1000; PageNumber <= 0
// becomes 1; OrderBy outside {created_at, timestamp, op_id} becomes created_at.
// The total is obtained through a separate Count call before the page query.
func (r *operationRepository) Query(ctx context.Context, req *OperationQueryRequest) (*OperationQueryResult, error) {
	if req == nil {
		return nil, fmt.Errorf("query request cannot be nil")
	}

	pageSize := req.PageSize
	switch {
	case pageSize <= 0:
		pageSize = 20
	case pageSize > 1000:
		pageSize = 1000
	}

	pageNumber := req.PageNumber
	if pageNumber <= 0 {
		pageNumber = 1
	}

	// The ORDER BY column is interpolated into the SQL text, so only a fixed
	// set of column names is accepted (guards against SQL injection).
	orderBy := "created_at"
	switch req.OrderBy {
	case "timestamp", "op_id":
		orderBy = req.OrderBy
	}

	orderDirection := "ASC"
	if req.OrderDesc {
		orderDirection = "DESC"
	}

	whereClause, args := buildOperationFilter(req)

	// Fetch the total first.
	total, err := r.Count(ctx, req)
	if err != nil {
		return nil, err
	}

	// Fetch the page. LIMIT/OFFSET placeholders come after all filter
	// placeholders, matching the order of args.
	offset := (pageNumber - 1) * pageSize
	query := r.convertPlaceholders(fmt.Sprintf(`
		SELECT
			op_id, op_actor, doid, producer_id,
			request_body_hash, response_body_hash,
			op_source, op_type, do_prefix, do_repository,
			client_ip, server_ip, trustlog_status, timestamp, created_at
		FROM operation
		%s
		ORDER BY %s %s
		LIMIT ? OFFSET ?
	`, whereClause, orderBy, orderDirection))
	args = append(args, pageSize, offset)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		r.logger.ErrorContext(ctx, "failed to query operations",
			"error", err,
		)
		return nil, fmt.Errorf("failed to query operations: %w", err)
	}
	defer rows.Close()

	var operations []*model.Operation
	var statuses []TrustlogStatus

	for rows.Next() {
		var op model.Operation
		var reqHash, respHash, clientIP, serverIP, statusStr sql.NullString
		// created_at is selected and scanned but not exposed on the result.
		var createdAt time.Time

		err := rows.Scan(
			&op.OpID,
			&op.OpActor,
			&op.Doid,
			&op.ProducerID,
			&reqHash,
			&respHash,
			&op.OpSource,
			&op.OpType,
			&op.DoPrefix,
			&op.DoRepository,
			&clientIP,
			&serverIP,
			&statusStr,
			&op.Timestamp,
			&createdAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan operation row: %w", err)
		}

		// Nullable columns.
		op.RequestBodyHash = stringPtrFromNull(reqHash)
		op.ResponseBodyHash = stringPtrFromNull(respHash)
		op.ClientIP = stringPtrFromNull(clientIP)
		op.ServerIP = stringPtrFromNull(serverIP)

		operations = append(operations, &op)
		// A NULL status yields the empty TrustlogStatus.
		statuses = append(statuses, TrustlogStatus(statusStr.String))
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating operation rows: %w", err)
	}

	// Total pages, rounded up.
	totalPages := int(total) / pageSize
	if int(total)%pageSize > 0 {
		totalPages++
	}

	return &OperationQueryResult{
		Operations: operations,
		Statuses:   statuses,
		Total:      total,
		PageSize:   pageSize,
		PageNumber: pageNumber,
		TotalPages: totalPages,
	}, nil
}

// Count returns the number of operation records matching the filters in req.
// Paging and ordering fields of req are ignored. It returns an error when req
// is nil.
func (r *operationRepository) Count(ctx context.Context, req *OperationQueryRequest) (int64, error) {
	if req == nil {
		return 0, fmt.Errorf("query request cannot be nil")
	}

	whereClause, args := buildOperationFilter(req)
	query := r.convertPlaceholders(fmt.Sprintf(`
		SELECT COUNT(*)
		FROM operation
		%s
	`, whereClause))

	var count int64
	if err := r.db.QueryRowContext(ctx, query, args...).Scan(&count); err != nil {
		r.logger.ErrorContext(ctx, "failed to count operations",
			"error", err,
		)
		return 0, fmt.Errorf("failed to count operations: %w", err)
	}

	return count, nil
}

// cursorRepository is the database/sql implementation of CursorRepository.
type cursorRepository struct {
	db         *sql.DB
	logger     logger.Logger
	driverName string
}

// NewCursorRepository creates a cursor repository on top of db. The driver
// flavor is detected once at construction time.
func NewCursorRepository(db *sql.DB, log logger.Logger) CursorRepository {
	return &cursorRepository{
		db:         db,
		logger:     log,
		driverName: detectDriverName(db),
	}
}
占位符转换为适合数据库的占位符 func (r *cursorRepository) convertPlaceholders(query string) string { return convertPlaceholdersForDriver(query, r.driverName) } // GetCursor 获取游标值(Key-Value 模式) func (r *cursorRepository) GetCursor(ctx context.Context, cursorKey string) (string, error) { query := r.convertPlaceholders(`SELECT cursor_value FROM trustlog_cursor WHERE cursor_key = ?`) var cursorValue string err := r.db.QueryRowContext(ctx, query, cursorKey).Scan(&cursorValue) if err == sql.ErrNoRows { r.logger.DebugContext(ctx, "cursor not found", "cursorKey", cursorKey, ) return "", nil } if err != nil { r.logger.ErrorContext(ctx, "failed to get cursor", "cursorKey", cursorKey, "error", err, ) return "", fmt.Errorf("failed to get cursor: %w", err) } return cursorValue, nil } // UpdateCursor 更新游标值 func (r *cursorRepository) UpdateCursor(ctx context.Context, cursorKey string, cursorValue string) error { return r.UpdateCursorTx(ctx, nil, cursorKey, cursorValue) } // UpdateCursorTx 在事务中更新游标值(使用 UPSERT) func (r *cursorRepository) UpdateCursorTx(ctx context.Context, tx *sql.Tx, cursorKey string, cursorValue string) error { // 使用 UPSERT 语法(适配不同数据库) query := r.convertPlaceholders(` INSERT INTO trustlog_cursor (cursor_key, cursor_value, last_updated_at) VALUES (?, ?, ?) 
ON CONFLICT (cursor_key) DO UPDATE SET cursor_value = excluded.cursor_value, last_updated_at = excluded.last_updated_at `) var err error now := time.Now() if tx != nil { _, err = tx.ExecContext(ctx, query, cursorKey, cursorValue, now) } else { _, err = r.db.ExecContext(ctx, query, cursorKey, cursorValue, now) } if err != nil { r.logger.ErrorContext(ctx, "failed to update cursor", "cursorKey", cursorKey, "error", err, ) return fmt.Errorf("failed to update cursor: %w", err) } r.logger.DebugContext(ctx, "cursor updated", "cursorKey", cursorKey, "cursorValue", cursorValue, ) return nil } // InitCursor 初始化游标(如果不存在) func (r *cursorRepository) InitCursor(ctx context.Context, cursorKey string, initialValue string) error { // 使用简单的 UPSERT:如果冲突则更新为新值 // 这样可以确保 cursor 总是基于最新的数据库状态初始化 query := r.convertPlaceholders(` INSERT INTO trustlog_cursor (cursor_key, cursor_value, last_updated_at) VALUES (?, ?, ?) ON CONFLICT (cursor_key) DO UPDATE SET cursor_value = EXCLUDED.cursor_value, last_updated_at = EXCLUDED.last_updated_at `) now := time.Now() _, err := r.db.ExecContext(ctx, query, cursorKey, initialValue, now) if err != nil { r.logger.ErrorContext(ctx, "failed to init cursor", "cursorKey", cursorKey, "error", err, ) return fmt.Errorf("failed to init cursor: %w", err) } r.logger.DebugContext(ctx, "cursor initialized", "cursorKey", cursorKey, "initialValue", initialValue, ) return nil } // retryRepository 重试仓储实现 type retryRepository struct { db *sql.DB logger logger.Logger driverName string } // NewRetryRepository 创建重试仓储 func NewRetryRepository(db *sql.DB, log logger.Logger) RetryRepository { driverName := detectDriverName(db) return &retryRepository{ db: db, logger: log, driverName: driverName, } } // convertPlaceholders 将 ? 
占位符转换为适合数据库的占位符 func (r *retryRepository) convertPlaceholders(query string) string { return convertPlaceholdersForDriver(query, r.driverName) } func (r *retryRepository) AddRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error { return r.AddRetryTx(ctx, nil, opID, errorMsg, nextRetryAt) } func (r *retryRepository) AddRetryTx(ctx context.Context, tx *sql.Tx, opID string, errorMsg string, nextRetryAt time.Time) error { query := r.convertPlaceholders(` INSERT INTO trustlog_retry (op_id, retry_count, retry_status, error_message, next_retry_at, updated_at) VALUES (?, 0, ?, ?, ?, ?) `) var err error if tx != nil { _, err = tx.ExecContext(ctx, query, opID, string(RetryStatusPending), errorMsg, nextRetryAt, time.Now()) } else { _, err = r.db.ExecContext(ctx, query, opID, string(RetryStatusPending), errorMsg, nextRetryAt, time.Now()) } if err != nil { r.logger.ErrorContext(ctx, "failed to add retry record", "opID", opID, "error", err, ) return fmt.Errorf("failed to add retry record: %w", err) } r.logger.DebugContext(ctx, "retry record added", "opID", opID, "nextRetryAt", nextRetryAt, ) return nil } func (r *retryRepository) IncrementRetry(ctx context.Context, opID string, errorMsg string, nextRetryAt time.Time) error { query := r.convertPlaceholders(` UPDATE trustlog_retry SET retry_count = retry_count + 1, retry_status = ?, last_retry_at = ?, next_retry_at = ?, error_message = ?, updated_at = ? WHERE op_id = ? 
`) _, err := r.db.ExecContext(ctx, query, string(RetryStatusRetrying), time.Now(), nextRetryAt, errorMsg, time.Now(), opID, ) if err != nil { r.logger.ErrorContext(ctx, "failed to increment retry", "opID", opID, "error", err, ) return fmt.Errorf("failed to increment retry: %w", err) } r.logger.DebugContext(ctx, "retry incremented", "opID", opID, "nextRetryAt", nextRetryAt, ) return nil } func (r *retryRepository) MarkAsDeadLetter(ctx context.Context, opID string, errorMsg string) error { query := r.convertPlaceholders(` UPDATE trustlog_retry SET retry_status = ?, error_message = ?, updated_at = ? WHERE op_id = ? `) _, err := r.db.ExecContext(ctx, query, string(RetryStatusDeadLetter), errorMsg, time.Now(), opID, ) if err != nil { r.logger.ErrorContext(ctx, "failed to mark as dead letter", "opID", opID, "error", err, ) return fmt.Errorf("failed to mark as dead letter: %w", err) } r.logger.WarnContext(ctx, "operation marked as dead letter", "opID", opID, "error", errorMsg, ) return nil } func (r *retryRepository) FindPendingRetries(ctx context.Context, limit int) ([]RetryRecord, error) { query := r.convertPlaceholders(` SELECT op_id, retry_count, retry_status, last_retry_at, next_retry_at, error_message, created_at, updated_at FROM trustlog_retry WHERE retry_status IN (?, ?) AND next_retry_at <= ? ORDER BY next_retry_at ASC LIMIT ? 
`) rows, err := r.db.QueryContext(ctx, query, string(RetryStatusPending), string(RetryStatusRetrying), time.Now(), limit, ) if err != nil { r.logger.ErrorContext(ctx, "failed to find pending retries", "error", err, ) return nil, fmt.Errorf("failed to find pending retries: %w", err) } defer rows.Close() var records []RetryRecord for rows.Next() { var record RetryRecord var lastRetry, nextRetry sql.NullTime err := rows.Scan( &record.OpID, &record.RetryCount, &record.RetryStatus, &lastRetry, &nextRetry, &record.ErrorMessage, &record.CreatedAt, &record.UpdatedAt, ) if err != nil { r.logger.ErrorContext(ctx, "failed to scan retry record", "error", err, ) return nil, fmt.Errorf("failed to scan retry record: %w", err) } if lastRetry.Valid { record.LastRetryAt = &lastRetry.Time } if nextRetry.Valid { record.NextRetryAt = &nextRetry.Time } records = append(records, record) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating retry records: %w", err) } return records, nil } func (r *retryRepository) DeleteRetry(ctx context.Context, opID string) error { query := r.convertPlaceholders(`DELETE FROM trustlog_retry WHERE op_id = ?`) _, err := r.db.ExecContext(ctx, query, opID) if err != nil { r.logger.ErrorContext(ctx, "failed to delete retry record", "opID", opID, "error", err, ) return fmt.Errorf("failed to delete retry record: %w", err) } r.logger.DebugContext(ctx, "retry record deleted", "opID", opID, ) return nil }