refactor: 重构trustlog-sdk目录结构到trustlog/go-trustlog
- 将所有trustlog-sdk文件移动到trustlog/go-trustlog/目录 - 更新README中所有import路径从trustlog-sdk改为go-trustlog - 更新cookiecutter配置文件中的项目名称 - 更新根目录.lefthook.yml以引用新位置的配置 - 添加go.sum文件到版本控制 - 删除过时的示例文件 这次重构与trustlog-server保持一致的目录结构, 为未来支持多语言SDK(Python、Java等)预留空间。
This commit is contained in:
56
.dockerignore
Normal file
56
.dockerignore
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Git
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
.github
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
**/__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
.pytest_cache/
|
||||||
|
..mypy_cache/
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv
|
||||||
|
.venv
|
||||||
|
.DS_Store
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
._*
|
||||||
|
|
||||||
|
# Temporary directories in the project
|
||||||
|
bin
|
||||||
|
tmp
|
||||||
|
|
||||||
|
./docker/
|
||||||
|
./coverage/
|
||||||
|
|
||||||
|
build-script.sh
|
||||||
|
codecov.yml
|
||||||
|
.editorconfig
|
||||||
|
.golangci.yml
|
||||||
|
.goreleaser.yml
|
||||||
|
.pre-commit-config.yaml
|
||||||
|
cookiecutter-config-file.yml
|
||||||
|
|
||||||
|
Makefile
|
||||||
|
LICENSE
|
||||||
|
README.md
|
||||||
|
CONTRIBUTING.md
|
||||||
|
SECURITY.md
|
||||||
|
CODE_OF_CONDUCT.md
|
||||||
23
.editorconfig
Normal file
23
.editorconfig
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Check http://editorconfig.org for more information
|
||||||
|
# This is the main config file for this project:
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
charset = utf-8
|
||||||
|
end_of_line = lf
|
||||||
|
insert_final_newline = true
|
||||||
|
indent_style = tab
|
||||||
|
indent_size = 4
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
|
||||||
|
[Makefile]
|
||||||
|
indent_style = tab # older versions of GNU Make do not work well with spaces
|
||||||
|
|
||||||
|
[*.{yaml,yml}]
|
||||||
|
indent_size = 2
|
||||||
|
|
||||||
|
[*.{md,rst}]
|
||||||
|
trim_trailing_whitespace = false
|
||||||
|
|
||||||
|
[*.{diff,patch}]
|
||||||
|
trim_trailing_whitespace = false
|
||||||
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Cache folders generated by IDE
|
||||||
|
.vscode/**
|
||||||
|
.idea/**
|
||||||
|
|
||||||
|
# Mac only
|
||||||
|
.DS_Store/**
|
||||||
|
|
||||||
|
# Ignore directories containing binaries generated and other stuff
|
||||||
|
tmp/**
|
||||||
|
bin/.gitignore
|
||||||
|
coverage/**
|
||||||
|
|
||||||
|
# Directories require atleast one file to be tracked by git.
|
||||||
|
!tmp/.gitkeep
|
||||||
|
!bin/.gitkeep
|
||||||
|
!coverage/.gitkeep
|
||||||
|
|
||||||
|
# Ignore config files from project root, keeping sample config files
|
||||||
|
/configs**
|
||||||
|
!/configs.**.sample
|
||||||
|
|
||||||
|
# Ignore all secret files — they'll be used for project sensitive data
|
||||||
|
**.secrets
|
||||||
|
**.session
|
||||||
|
**.secret
|
||||||
|
**.cookie
|
||||||
470
.golangci.yml
Normal file
470
.golangci.yml
Normal file
@@ -0,0 +1,470 @@
|
|||||||
|
# Based on https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322
|
||||||
|
# https://gist.githubusercontent.com/maratori/47a4d00457a92aa426dbd48a18776322/raw/2d44b6316e49bde912e0de76456d016fd53604f4/.golangci.yml
|
||||||
|
|
||||||
|
version: "2"
|
||||||
|
|
||||||
|
issues:
|
||||||
|
# Maximum count of issues with the same text.
|
||||||
|
# Set to 0 to disable.
|
||||||
|
# Default: 3
|
||||||
|
max-same-issues: 50
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- gofumpt # enforces a stricter format than 'gofmt', while being backwards compatible
|
||||||
|
- gci # checks if code and import statements are formatted, with additional rules
|
||||||
|
- golines # checks if code is formatted, and fixes long lines
|
||||||
|
|
||||||
|
## you may want to enable
|
||||||
|
#- gofmt # checks if the code is formatted according to 'gofmt' command
|
||||||
|
#- goimports # checks if the code and import statements are formatted according to the 'goimports' command
|
||||||
|
#- swaggo # formats swaggo comments
|
||||||
|
|
||||||
|
# All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml
|
||||||
|
settings:
|
||||||
|
gofumpt:
|
||||||
|
# Module path which contains the source code being formatted.
|
||||||
|
module-path: "go.yandata.net/iod/iod/trustlog-sdk"
|
||||||
|
# Choose whether to use the extra rules.
|
||||||
|
extra-rules: true
|
||||||
|
|
||||||
|
gci:
|
||||||
|
sections:
|
||||||
|
- standard # standard packages
|
||||||
|
- default # all imports that could not be matched to another section type
|
||||||
|
- localmodule # all local packages
|
||||||
|
|
||||||
|
golines:
|
||||||
|
# Target maximum line length.
|
||||||
|
# Default: 100
|
||||||
|
max-len: 120
|
||||||
|
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- asasalint # checks for pass []any as any in variadic func(...any)
|
||||||
|
- asciicheck # checks that your code does not contain non-ASCII identifiers
|
||||||
|
- bidichk # checks for dangerous unicode character sequences
|
||||||
|
- bodyclose # checks whether HTTP response body is closed successfully
|
||||||
|
- canonicalheader # checks whether net/http.Header uses canonical header
|
||||||
|
- copyloopvar # detects places where loop variables are copied (Go 1.22+)
|
||||||
|
- cyclop # checks function and package cyclomatic complexity
|
||||||
|
- depguard # checks if package imports are in a list of acceptable packages
|
||||||
|
- dupl # tool for code clone detection
|
||||||
|
- durationcheck # checks for two durations multiplied together
|
||||||
|
- embeddedstructfieldcheck # checks embedded types in structs
|
||||||
|
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||||
|
- errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error
|
||||||
|
- errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13
|
||||||
|
- exhaustive # checks exhaustiveness of enum switch statements
|
||||||
|
- exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions
|
||||||
|
- fatcontext # detects nested contexts in loops
|
||||||
|
- forbidigo # forbids identifiers
|
||||||
|
- funcorder # checks the order of functions, methods, and constructors
|
||||||
|
- funlen # tool for detection of long functions
|
||||||
|
- gocheckcompilerdirectives # validates go compiler directive comments (//go:)
|
||||||
|
- gochecknoglobals # checks that no global variables exist
|
||||||
|
- gochecknoinits # checks that no init functions are present in Go code
|
||||||
|
- gochecksumtype # checks exhaustiveness on Go "sum types"
|
||||||
|
- gocognit # computes and checks the cognitive complexity of functions
|
||||||
|
- goconst # finds repeated strings that could be replaced by a constant
|
||||||
|
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||||
|
- gocyclo # computes and checks the cyclomatic complexity of functions
|
||||||
|
- godot # checks if comments end in a period
|
||||||
|
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
|
||||||
|
- goprintffuncname # checks that printf-like functions are named with f at the end
|
||||||
|
- gosec # inspects source code for security problems
|
||||||
|
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||||
|
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
|
||||||
|
- ineffassign # detects when assignments to existing variables are not used
|
||||||
|
- intrange # finds places where for loops could make use of an integer range
|
||||||
|
- loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap)
|
||||||
|
- makezero # finds slice declarations with non-zero initial length
|
||||||
|
- mirror # reports wrong mirror patterns of bytes/strings usage
|
||||||
|
- mnd # detects magic numbers
|
||||||
|
- musttag # enforces field tags in (un)marshaled structs
|
||||||
|
- nakedret # finds naked returns in functions greater than a specified function length
|
||||||
|
- nestif # reports deeply nested if statements
|
||||||
|
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||||
|
- nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr)
|
||||||
|
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||||
|
- noctx # finds sending http request without context.Context
|
||||||
|
- nolintlint # reports ill-formed or insufficient nolint directives
|
||||||
|
# - nonamedreturns # reports all named returns
|
||||||
|
- nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL
|
||||||
|
- perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative
|
||||||
|
- predeclared # finds code that shadows one of Go's predeclared identifiers
|
||||||
|
- promlinter # checks Prometheus metrics naming via promlint
|
||||||
|
- protogetter # reports direct reads from proto message fields when getters should be used
|
||||||
|
- reassign # checks that package variables are not reassigned
|
||||||
|
- recvcheck # checks for receiver type consistency
|
||||||
|
- revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint
|
||||||
|
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||||
|
- sloglint # ensure consistent code style when using log/slog
|
||||||
|
- spancheck # checks for mistakes with OpenTelemetry/Census spans
|
||||||
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
|
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||||
|
- testableexamples # checks if examples are testable (have an expected output)
|
||||||
|
- testifylint # checks usage of github.com/stretchr/testify
|
||||||
|
- testpackage # makes you use a separate _test package
|
||||||
|
- tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||||
|
- unconvert # removes unnecessary type conversions
|
||||||
|
# - unparam # reports unused function parameters
|
||||||
|
- unused # checks for unused constants, variables, functions and types
|
||||||
|
- usestdlibvars # detects the possibility to use variables/constants from the Go standard library
|
||||||
|
- usetesting # reports uses of functions with replacement inside the testing package
|
||||||
|
- wastedassign # finds wasted assignment statements
|
||||||
|
- whitespace # detects leading and trailing whitespace
|
||||||
|
|
||||||
|
## you may want to enable
|
||||||
|
#- arangolint # opinionated best practices for arangodb client
|
||||||
|
#- decorder # checks declaration order and count of types, constants, variables and functions
|
||||||
|
#- exhaustruct # [highly recommend to enable] checks if all structure fields are initialized
|
||||||
|
#- ginkgolinter # [if you use ginkgo/gomega] enforces standards of using ginkgo and gomega
|
||||||
|
#- godox # detects usage of FIXME, TODO and other keywords inside comments
|
||||||
|
#- goheader # checks is file header matches to pattern
|
||||||
|
#- inamedparam # [great idea, but too strict, need to ignore a lot of cases by default] reports interfaces with unnamed method parameters
|
||||||
|
#- interfacebloat # checks the number of methods inside an interface
|
||||||
|
#- ireturn # accept interfaces, return concrete types
|
||||||
|
#- noinlineerr # disallows inline error handling `if err := ...; err != nil {`
|
||||||
|
#- prealloc # [premature optimization, but can be used in some cases] finds slice declarations that could potentially be preallocated
|
||||||
|
#- tagalign # checks that struct tags are well aligned
|
||||||
|
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
|
||||||
|
#- wrapcheck # checks that errors returned from external packages are wrapped
|
||||||
|
#- zerologlint # detects the wrong usage of zerolog that a user forgets to dispatch zerolog.Event
|
||||||
|
|
||||||
|
## disabled
|
||||||
|
#- containedctx # detects struct contained context.Context field
|
||||||
|
#- contextcheck # [too many false positives] checks the function whether use a non-inherited context
|
||||||
|
#- dogsled # checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||||
|
#- dupword # [useless without config] checks for duplicate words in the source code
|
||||||
|
#- err113 # [too strict] checks the errors handling expressions
|
||||||
|
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted
|
||||||
|
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
|
||||||
|
#- gomodguard # [use more powerful depguard] allow and block lists linter for direct Go module dependencies
|
||||||
|
#- gosmopolitan # reports certain i18n/l10n anti-patterns in your Go codebase
|
||||||
|
#- grouper # analyzes expression groups
|
||||||
|
#- importas # enforces consistent import aliases
|
||||||
|
#- lll # [replaced by golines] reports long lines
|
||||||
|
#- maintidx # measures the maintainability index of each function
|
||||||
|
#- misspell # [useless] finds commonly misspelled English words in comments
|
||||||
|
#- nlreturn # [too strict and mostly code is not more readable] checks for a new line before return and branch statements to increase code clarity
|
||||||
|
#- paralleltest # [too many false positives] detects missing usage of t.Parallel() method in your Go test
|
||||||
|
#- tagliatelle # checks the struct tags
|
||||||
|
#- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||||
|
#- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines
|
||||||
|
#- wsl_v5 # [too strict and mostly code is not more readable] add or remove empty lines
|
||||||
|
|
||||||
|
# All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml
|
||||||
|
settings:
|
||||||
|
cyclop:
|
||||||
|
# The maximal code complexity to report.
|
||||||
|
# Default: 10
|
||||||
|
max-complexity: 30
|
||||||
|
# The maximal average package complexity.
|
||||||
|
# If it's higher than 0.0 (float) the check is enabled.
|
||||||
|
# Default: 0.0
|
||||||
|
package-average: 10.0
|
||||||
|
|
||||||
|
depguard:
|
||||||
|
# Rules to apply.
|
||||||
|
#
|
||||||
|
# Variables:
|
||||||
|
# - File Variables
|
||||||
|
# Use an exclamation mark `!` to negate a variable.
|
||||||
|
# Example: `!$test` matches any file that is not a go test file.
|
||||||
|
#
|
||||||
|
# `$all` - matches all go files
|
||||||
|
# `$test` - matches all go test files
|
||||||
|
#
|
||||||
|
# - Package Variables
|
||||||
|
#
|
||||||
|
# `$gostd` - matches all of go's standard library (Pulled from `GOROOT`)
|
||||||
|
#
|
||||||
|
# Default (applies if no custom rules are defined): Only allow $gostd in all files.
|
||||||
|
rules:
|
||||||
|
"deprecated":
|
||||||
|
# List of file globs that will match this list of settings to compare against.
|
||||||
|
# By default, if a path is relative, it is relative to the directory where the golangci-lint command is executed.
|
||||||
|
# The placeholder '${base-path}' is substituted with a path relative to the mode defined with `run.relative-path-mode`.
|
||||||
|
# The placeholder '${config-path}' is substituted with a path relative to the configuration file.
|
||||||
|
# Default: $all
|
||||||
|
files:
|
||||||
|
- "$all"
|
||||||
|
# List of packages that are not allowed.
|
||||||
|
# Entries can be a variable (starting with $), a string prefix, or an exact match (if ending with $).
|
||||||
|
# Default: []
|
||||||
|
deny:
|
||||||
|
- pkg: github.com/golang/protobuf
|
||||||
|
desc: Use google.golang.org/protobuf instead, see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
|
||||||
|
- pkg: github.com/satori/go.uuid
|
||||||
|
desc: Use github.com/google/uuid instead, satori's package is not maintained
|
||||||
|
- pkg: github.com/gofrs/uuid$
|
||||||
|
desc: Use github.com/gofrs/uuid/v5 or later, it was not a go module before v5
|
||||||
|
"non-test files":
|
||||||
|
files:
|
||||||
|
- "!$test"
|
||||||
|
deny:
|
||||||
|
- pkg: math/rand$
|
||||||
|
desc: Use math/rand/v2 instead, see https://go.dev/blog/randv2
|
||||||
|
"non-main files":
|
||||||
|
files:
|
||||||
|
- "!**/main.go"
|
||||||
|
deny:
|
||||||
|
- pkg: log$
|
||||||
|
desc: Use log/slog instead, see https://go.dev/blog/slog
|
||||||
|
|
||||||
|
embeddedstructfieldcheck:
|
||||||
|
# Checks that sync.Mutex and sync.RWMutex are not used as embedded fields.
|
||||||
|
# Default: false
|
||||||
|
forbid-mutex: true
|
||||||
|
|
||||||
|
errcheck:
|
||||||
|
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||||
|
# Such cases aren't reported by default.
|
||||||
|
# Default: false
|
||||||
|
check-type-assertions: true
|
||||||
|
|
||||||
|
exhaustive:
|
||||||
|
# Program elements to check for exhaustiveness.
|
||||||
|
# Default: [ switch ]
|
||||||
|
check:
|
||||||
|
- switch
|
||||||
|
- map
|
||||||
|
|
||||||
|
exhaustruct:
|
||||||
|
# List of regular expressions to exclude struct packages and their names from checks.
|
||||||
|
# Regular expressions must match complete canonical struct package/name/structname.
|
||||||
|
# Default: []
|
||||||
|
exclude:
|
||||||
|
# std libs
|
||||||
|
- ^net/http.Client$
|
||||||
|
- ^net/http.Cookie$
|
||||||
|
- ^net/http.Request$
|
||||||
|
- ^net/http.Response$
|
||||||
|
- ^net/http.Server$
|
||||||
|
- ^net/http.Transport$
|
||||||
|
- ^net/url.URL$
|
||||||
|
- ^os/exec.Cmd$
|
||||||
|
- ^reflect.StructField$
|
||||||
|
# public libs
|
||||||
|
- ^github.com/Shopify/sarama.Config$
|
||||||
|
- ^github.com/Shopify/sarama.ProducerMessage$
|
||||||
|
- ^github.com/mitchellh/mapstructure.DecoderConfig$
|
||||||
|
- ^github.com/prometheus/client_golang/.+Opts$
|
||||||
|
- ^github.com/spf13/cobra.Command$
|
||||||
|
- ^github.com/spf13/cobra.CompletionOptions$
|
||||||
|
- ^github.com/stretchr/testify/mock.Mock$
|
||||||
|
- ^github.com/testcontainers/testcontainers-go.+Request$
|
||||||
|
- ^github.com/testcontainers/testcontainers-go.FromDockerfile$
|
||||||
|
- ^golang.org/x/tools/go/analysis.Analyzer$
|
||||||
|
- ^google.golang.org/protobuf/.+Options$
|
||||||
|
- ^gopkg.in/yaml.v3.Node$
|
||||||
|
|
||||||
|
funcorder:
|
||||||
|
# Checks if the exported methods of a structure are placed before the non-exported ones.
|
||||||
|
# Default: true
|
||||||
|
struct-method: false
|
||||||
|
|
||||||
|
funlen:
|
||||||
|
# Checks the number of lines in a function.
|
||||||
|
# If lower than 0, disable the check.
|
||||||
|
# Default: 60
|
||||||
|
lines: 100
|
||||||
|
# Checks the number of statements in a function.
|
||||||
|
# If lower than 0, disable the check.
|
||||||
|
# Default: 40
|
||||||
|
statements: 50
|
||||||
|
|
||||||
|
gochecksumtype:
|
||||||
|
# Presence of `default` case in switch statements satisfies exhaustiveness, if all members are not listed.
|
||||||
|
# Default: true
|
||||||
|
default-signifies-exhaustive: false
|
||||||
|
|
||||||
|
gocognit:
|
||||||
|
# Minimal code complexity to report.
|
||||||
|
# Default: 30 (but we recommend 10-20)
|
||||||
|
min-complexity: 20
|
||||||
|
|
||||||
|
gocritic:
|
||||||
|
# Settings passed to gocritic.
|
||||||
|
# The settings key is the name of a supported gocritic checker.
|
||||||
|
# The list of supported checkers can be found at https://go-critic.com/overview.
|
||||||
|
settings:
|
||||||
|
captLocal:
|
||||||
|
# Whether to restrict checker to params only.
|
||||||
|
# Default: true
|
||||||
|
paramsOnly: false
|
||||||
|
underef:
|
||||||
|
# Whether to skip (*x).method() calls where x is a pointer receiver.
|
||||||
|
# Default: true
|
||||||
|
skipRecvDeref: false
|
||||||
|
|
||||||
|
govet:
|
||||||
|
# Enable all analyzers.
|
||||||
|
# Default: false
|
||||||
|
enable-all: true
|
||||||
|
# Disable analyzers by name.
|
||||||
|
# Run `GL_DEBUG=govet golangci-lint run --enable=govet` to see default, all available analyzers, and enabled analyzers.
|
||||||
|
# Default: []
|
||||||
|
disable:
|
||||||
|
- fieldalignment # too strict
|
||||||
|
# Settings per analyzer.
|
||||||
|
settings: {}
|
||||||
|
# shadow:
|
||||||
|
# # Whether to be strict about shadowing; can be noisy.
|
||||||
|
# # Default: false
|
||||||
|
# strict: true
|
||||||
|
|
||||||
|
inamedparam:
|
||||||
|
# Skips check for interface methods with only a single parameter.
|
||||||
|
# Default: false
|
||||||
|
skip-single-param: true
|
||||||
|
|
||||||
|
mnd:
|
||||||
|
# List of function patterns to exclude from analysis.
|
||||||
|
# Values always ignored: `time.Date`,
|
||||||
|
# `strconv.FormatInt`, `strconv.FormatUint`, `strconv.FormatFloat`,
|
||||||
|
# `strconv.ParseInt`, `strconv.ParseUint`, `strconv.ParseFloat`.
|
||||||
|
# Default: []
|
||||||
|
ignored-functions:
|
||||||
|
- args.Error
|
||||||
|
- flag.Arg
|
||||||
|
- flag.Duration.*
|
||||||
|
- flag.Float.*
|
||||||
|
- flag.Int.*
|
||||||
|
- flag.Uint.*
|
||||||
|
- os.Chmod
|
||||||
|
- os.Mkdir.*
|
||||||
|
- os.OpenFile
|
||||||
|
- os.WriteFile
|
||||||
|
- prometheus.ExponentialBuckets.*
|
||||||
|
- prometheus.LinearBuckets
|
||||||
|
|
||||||
|
nakedret:
|
||||||
|
# Make an issue if func has more lines of code than this setting, and it has naked returns.
|
||||||
|
# Default: 30
|
||||||
|
max-func-lines: 0
|
||||||
|
|
||||||
|
nolintlint:
|
||||||
|
# Exclude following linters from requiring an explanation.
|
||||||
|
# Default: []
|
||||||
|
allow-no-explanation: [ funlen, gocognit, golines ]
|
||||||
|
# Enable to require an explanation of nonzero length after each nolint directive.
|
||||||
|
# Default: false
|
||||||
|
require-explanation: true
|
||||||
|
# Enable to require nolint directives to mention the specific linter being suppressed.
|
||||||
|
# Default: false
|
||||||
|
require-specific: true
|
||||||
|
|
||||||
|
perfsprint:
|
||||||
|
# Optimizes into strings concatenation.
|
||||||
|
# Default: true
|
||||||
|
strconcat: false
|
||||||
|
|
||||||
|
reassign:
|
||||||
|
# Patterns for global variable names that are checked for reassignment.
|
||||||
|
# See https://github.com/curioswitch/go-reassign#usage
|
||||||
|
# Default: ["EOF", "Err.*"]
|
||||||
|
patterns:
|
||||||
|
- ".*"
|
||||||
|
|
||||||
|
revive:
|
||||||
|
rules:
|
||||||
|
- name: blank-imports
|
||||||
|
disabled: true
|
||||||
|
- name: unused-parameter
|
||||||
|
arguments:
|
||||||
|
- allow-regex: "^_|ctx|req"
|
||||||
|
|
||||||
|
rowserrcheck:
|
||||||
|
# database/sql is always checked.
|
||||||
|
# Default: []
|
||||||
|
packages:
|
||||||
|
- github.com/jmoiron/sqlx
|
||||||
|
|
||||||
|
sloglint:
|
||||||
|
# Enforce not using global loggers.
|
||||||
|
# Values:
|
||||||
|
# - "": disabled
|
||||||
|
# - "all": report all global loggers
|
||||||
|
# - "default": report only the default slog logger
|
||||||
|
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#no-global
|
||||||
|
# Default: ""
|
||||||
|
no-global: all
|
||||||
|
# Enforce using methods that accept a context.
|
||||||
|
# Values:
|
||||||
|
# - "": disabled
|
||||||
|
# - "all": report all contextless calls
|
||||||
|
# - "scope": report only if a context exists in the scope of the outermost function
|
||||||
|
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#context-only
|
||||||
|
# Default: ""
|
||||||
|
context: scope
|
||||||
|
|
||||||
|
staticcheck:
|
||||||
|
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
|
||||||
|
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
|
||||||
|
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
|
||||||
|
checks:
|
||||||
|
- all
|
||||||
|
# Incorrect or missing package comment.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1000
|
||||||
|
- -ST1000
|
||||||
|
# Use consistent method receiver names.
|
||||||
|
# https://staticcheck.dev/docs/checks/#ST1016
|
||||||
|
- -ST1016
|
||||||
|
# Omit embedded fields from selector expression.
|
||||||
|
# https://staticcheck.dev/docs/checks/#QF1008
|
||||||
|
- -QF1008
|
||||||
|
|
||||||
|
usetesting:
|
||||||
|
# Enable/disable `os.TempDir()` detections.
|
||||||
|
# Default: false
|
||||||
|
os-temp-dir: true
|
||||||
|
|
||||||
|
exclusions:
|
||||||
|
# Log a warning if an exclusion rule is unused.
|
||||||
|
# Default: false
|
||||||
|
warn-unused: true
|
||||||
|
# Predefined exclusion rules.
|
||||||
|
# Default: []
|
||||||
|
presets:
|
||||||
|
- std-error-handling
|
||||||
|
- common-false-positives
|
||||||
|
# Excluding configuration per-path, per-linter, per-text and per-source.
|
||||||
|
rules:
|
||||||
|
- source: 'TODO'
|
||||||
|
linters: [ godot ]
|
||||||
|
- text: 'should have a package comment'
|
||||||
|
linters: [ revive ]
|
||||||
|
- text: 'exported \S+ \S+ should have comment( \(or a comment on this block\))? or be unexported'
|
||||||
|
linters: [ revive ]
|
||||||
|
- text: 'package comment should be of the form ".+"'
|
||||||
|
source: '// ?(nolint|TODO)'
|
||||||
|
linters: [ revive ]
|
||||||
|
- text: 'comment on exported \S+ \S+ should be of the form ".+"'
|
||||||
|
source: '// ?(nolint|TODO)'
|
||||||
|
linters: [ revive, staticcheck ]
|
||||||
|
# 忽略弱加密算法警告 - 这些算法用于业务兼容性需求
|
||||||
|
- path: 'internal/model/entity/hash\.go'
|
||||||
|
text: '(G401|G501|G505|G506|G507)'
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
- text: 'Blocklisted import (crypto/md5|crypto/sha1): weak cryptographic primitive'
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
- text: 'Blocklisted import (golang\.org/x/crypto/md4|golang\.org/x/crypto/ripemd160): deprecated and weak cryptographic primitive'
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
- text: 'Use of (weak|deprecated weak) cryptographic primitive'
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
# Allow shadowed variables named "err"
|
||||||
|
- text: 'shadow: declaration of "err" shadows declaration at line \d+'
|
||||||
|
linters: [ govet ]
|
||||||
|
# Ignore unused variables in test/mock files
|
||||||
|
- path: '.+_(test|mock)\.go'
|
||||||
|
text: "unused-parameter: parameter '.+' seems to be unused, consider removing or renaming it to match"
|
||||||
|
linters: [ revive ]
|
||||||
|
# Allow xx/utils, xx/config, xx/client, etc.
|
||||||
|
- text: 'avoid meaningless package names'
|
||||||
|
linters: [ revive ]
|
||||||
12
.lefthook.yml
Normal file
12
.lefthook.yml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
assert_lefthook_installed: true
|
||||||
|
|
||||||
|
pre-commit:
|
||||||
|
parallel: true
|
||||||
|
commands:
|
||||||
|
pre-commit:
|
||||||
|
run: pre-commit run --files {staged_files}
|
||||||
|
stage_fixed: true
|
||||||
|
lint:
|
||||||
|
glob: "*.go"
|
||||||
|
run: golangci-lint run --fix {staged_files}
|
||||||
|
stage_fixed: true
|
||||||
39
.pre-commit-config.yaml
Normal file
39
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
default_stages: [commit, push]
|
||||||
|
fail_fast: true # stop running hooks at the first failure
|
||||||
|
minimum_pre_commit_version: 4.2.0
|
||||||
|
|
||||||
|
exclude: |
|
||||||
|
(?x)(
|
||||||
|
cookiecutter-config-file.yml|
|
||||||
|
.cruft.json
|
||||||
|
)
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
# Verify syntax
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-json
|
||||||
|
- id: check-xml
|
||||||
|
|
||||||
|
# Checkers
|
||||||
|
- id: check-merge-conflict # check for merge conflict string
|
||||||
|
- id: detect-private-key # check for existence of private keys
|
||||||
|
|
||||||
|
# Implicit minor corrections to files
|
||||||
|
- id: end-of-file-fixer # ensure all files end with a new line
|
||||||
|
- id: trailing-whitespace # trim trailing whitespaces
|
||||||
|
|
||||||
|
# Good practices for executables/scripts
|
||||||
|
- id: check-executables-have-shebangs # non-binary executables have shebang
|
||||||
|
- id: check-shebang-scripts-are-executable # verify shebang scripts are executable
|
||||||
|
|
||||||
|
# Pretty format JSON files implicitly
|
||||||
|
- id: pretty-format-json
|
||||||
|
args: ["--autofix"]
|
||||||
|
|
||||||
|
# Debatable -- for most projects adding files larger than 10 MB is likely to be
|
||||||
|
# a mistake instead of a requirement. Remove this if needed
|
||||||
|
- id: check-added-large-files # fails if a file larger than 10 MB exists
|
||||||
|
args: ["--maxkb=10240", "--enforce-all"]
|
||||||
156
Makefile
Normal file
156
Makefile
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# Makefile - use `make` or `make help` to get a list of commands.
|
||||||
|
#
|
||||||
|
# Note - Comments inside this makefile should be made using a single
|
||||||
|
# hashtag `#`, lines with double hash-tags will be the messages that
|
||||||
|
# printed in the help command
|
||||||
|
|
||||||
|
# Name of the current directory
|
||||||
|
PROJECTNAME="trustlog-sdk"
|
||||||
|
|
||||||
|
# List of all Go-files to be processed
|
||||||
|
GOFILES=$(wildcard *.go)
|
||||||
|
|
||||||
|
# Docker image variables
|
||||||
|
IMAGE := $(PROJECTNAME)
|
||||||
|
VERSION := latest
|
||||||
|
|
||||||
|
# Ensures firing a blank `make` command default to help
|
||||||
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
|
# Make is verbose in Linux. Make it silent
|
||||||
|
MAKEFLAGS += --silent
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: help
|
||||||
|
## `help`: Generates this help dialog for the Makefile
|
||||||
|
help: Makefile
|
||||||
|
echo
|
||||||
|
echo " Commands available in \`"$(PROJECTNAME)"\`:"
|
||||||
|
echo
|
||||||
|
sed -n 's/^[ \t]*##//p' $< | column -t -s ':' | sed -e 's/^//'
|
||||||
|
echo
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: local-setup
|
||||||
|
## `local-setup`: Setup development environment locally
|
||||||
|
local-setup:
|
||||||
|
echo " > Ensuring directory is a git repository"
|
||||||
|
git init &> /dev/null
|
||||||
|
echo " > Installing pre-commit"
|
||||||
|
pip install --upgrade pre-commit &> /dev/null
|
||||||
|
pre-commit install
|
||||||
|
|
||||||
|
|
||||||
|
# Will install missing dependencies
|
||||||
|
.PHONY: install
|
||||||
|
## `install`: Fetch dependencies needed to run `trustlog-sdk`
|
||||||
|
install:
|
||||||
|
echo " > Getting dependencies..."
|
||||||
|
go get -v $(get)
|
||||||
|
go mod tidy
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: codestyle
|
||||||
|
## :
|
||||||
|
## `codestyle`: Run code formatter(s)
|
||||||
|
codestyle:
|
||||||
|
golangci-lint run --fix
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: lint
|
||||||
|
## `lint`: Run linters and check code-style
|
||||||
|
lint:
|
||||||
|
golangci-lint run
|
||||||
|
|
||||||
|
|
||||||
|
# No `help` message for this command - designed to be consumed internally
|
||||||
|
.PHONY: --test-runner
|
||||||
|
--test-runner:
|
||||||
|
go test ./... -race -covermode=atomic -coverprofile=./coverage/coverage.txt
|
||||||
|
go tool cover -html=./coverage/coverage.txt -o ./coverage/coverage.html
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
## :
|
||||||
|
## `test`: Run all tests
|
||||||
|
test: export TEST_MODE=complete
|
||||||
|
test: --test-runner
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: fast-tests
|
||||||
|
## `fast-tests`: Selectively run fast tests
|
||||||
|
fast-tests: export TEST_MODE=fast
|
||||||
|
fast-tests: --test-runner
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: slow-tests
|
||||||
|
## `slow-tests`: Selectively run slow tests
|
||||||
|
slow-tests: export TEST_MODE=slow
|
||||||
|
slow-tests: --test-runner
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test-suite
|
||||||
|
## `test-suite`: Check code style, run linters and ALL tests
|
||||||
|
test-suite: export TEST_MODE=complete
|
||||||
|
test-suite: lint test
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
## :
|
||||||
|
## `run`: Run `trustlog-sdk` in production mode
|
||||||
|
run: export production_mode=production
|
||||||
|
run: export __BUILD_MODE__=production
|
||||||
|
run:
|
||||||
|
go run main.go $(q)
|
||||||
|
|
||||||
|
.PHONY: run-debug
|
||||||
|
## `run-debug`: Run `trustlog-sdk` in debug mode
|
||||||
|
run-debug: export debug_mode=debug
|
||||||
|
run-debug: export __BUILD_MODE__=debug
|
||||||
|
run-debug:
|
||||||
|
go run main.go $(q)
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: docker-gen
|
||||||
|
## :
|
||||||
|
## `docker-gen`: Create a production docker image for `trustlog-sdk`
|
||||||
|
docker-gen:
|
||||||
|
echo "Building docker image \`$(IMAGE):$(VERSION)\`..."
|
||||||
|
docker build --rm \
|
||||||
|
--build-arg final_image=scratch \
|
||||||
|
--build-arg build_mode=production \
|
||||||
|
-t $(IMAGE):$(VERSION) . \
|
||||||
|
-f ./docker/Dockerfile
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: docker-debug
|
||||||
|
## `docker-debug`: Create debug-friendly docker images for `trustlog-sdk`
|
||||||
|
docker-debug:
|
||||||
|
echo "Building docker image \`$(IMAGE):$(VERSION)\`..."
|
||||||
|
docker build --rm=false \
|
||||||
|
--build-arg final_image=golang:1.24 \
|
||||||
|
--build-arg build_mode=debug \
|
||||||
|
-t $(IMAGE)-debug:$(VERSION) . \
|
||||||
|
-f ./docker/Dockerfile
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: clean-docker
|
||||||
|
## `clean-docker`: Delete an existing docker image
|
||||||
|
clean-docker:
|
||||||
|
echo "Removing docker $(IMAGE):$(VERSION)..."
|
||||||
|
docker rmi -f $(IMAGE):$(VERSION)
|
||||||
|
|
||||||
|
|
||||||
|
## :
|
||||||
|
## NOTE: All docker-related commands can use `IMAGE`
|
||||||
|
## : and `VERSION` variables to modify the docker
|
||||||
|
## : image being targeted
|
||||||
|
## :
|
||||||
|
## : Example;
|
||||||
|
## : make docker-gen IMAGE=new_project VERSION=3.15
|
||||||
|
## :
|
||||||
|
## : Likewise, both the `run` commands can pass runtime
|
||||||
|
## : arguments under the `q` arg
|
||||||
|
## :
|
||||||
|
## : Example;
|
||||||
|
## : `make run q="time --version"`
|
||||||
967
README.md
Normal file
967
README.md
Normal file
@@ -0,0 +1,967 @@
|
|||||||
|
# Trustlog-SDK 使用说明
|
||||||
|
|
||||||
|
本 SDK 提供基于 [Watermill](https://watermill.io/) 抽象层的统一消息发送与接收能力,以及基于 gRPC 的操作查询和取证验证功能。
|
||||||
|
|
||||||
|
SDK 支持两种数据模型:
|
||||||
|
- **`Operation`**(操作记录):用于记录完整的业务操作,包含请求/响应体哈希,支持完整的取证验证
|
||||||
|
- **`Record`**(简单记录):用于记录简单的事件或日志,轻量级,适合日志和事件追踪场景
|
||||||
|
|
||||||
|
两种模型分别发布到不同的 Topic,通过统一的 `HighClient` 和 `QueryClient` 进行操作。支持通过 Watermill Forwarder 将消息持久化到 SQL 数据库,实现事务性保证。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🚀 安装
|
||||||
|
|
||||||
|
### 1. 私有仓库配置(重要)
|
||||||
|
|
||||||
|
由于本 SDK 托管在私有仓库,需要配置 SSH 映射和禁用 Go Module 校验:
|
||||||
|
|
||||||
|
#### 配置 Git SSH 映射(跳过 HTTPS 验证)
|
||||||
|
```bash
|
||||||
|
git config --global url."git@go.yandata.net:".insteadOf "https://go.yandata.net"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 禁用 Go Module Sum 校验
|
||||||
|
```bash
|
||||||
|
go env -w GOPRIVATE="go.yandata.net"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 安装 SDK
|
||||||
|
```bash
|
||||||
|
go get go.yandata.net/iod/iod/go-trustlog
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📦 核心概念
|
||||||
|
|
||||||
|
### 数据模型
|
||||||
|
|
||||||
|
SDK 提供两种数据模型,分别适用于不同的业务场景:
|
||||||
|
|
||||||
|
#### 1. Operation(操作记录)
|
||||||
|
|
||||||
|
`Operation` 用于记录完整的业务操作,包含完整的元数据、请求/响应体哈希等信息,支持完整的取证验证流程。
|
||||||
|
|
||||||
|
**适用场景**:
|
||||||
|
- 记录 DOIP/IRP 协议的完整操作(Create、Update、Delete、Retrieve 等)
|
||||||
|
- 需要完整记录请求和响应的审计场景
|
||||||
|
- 需要支持完整取证验证的操作记录
|
||||||
|
|
||||||
|
**核心字段**:
|
||||||
|
- `Meta`:操作元数据
|
||||||
|
- `OpID`:操作唯一标识符(自动生成 UUID v7)
|
||||||
|
- `Timestamp`:操作时间戳(必填)
|
||||||
|
- `OpSource`:操作来源(`DOIP` 或 `IRP`)
|
||||||
|
- `OpType`:操作类型(如 `Create`、`Update`、`Delete` 等)
|
||||||
|
- `OpAlgorithm`:哈希算法类型(默认 `Sha256Simd`)
|
||||||
|
- `OpMetaHash`:元数据哈希值(自动计算)
|
||||||
|
- `DataID`:数据标识
|
||||||
|
- `DoPrefix`:DO 前缀(必填)
|
||||||
|
- `DoRepository`:仓库名(必填)
|
||||||
|
- `Doid`:完整 DOID(必填,格式:`{DoPrefix}/{DoRepository}/{object}`)
|
||||||
|
- `OpActor`:操作发起者(默认 `SYSTEM`)
|
||||||
|
- `RequestBodyHash`:请求体哈希值(必填)
|
||||||
|
- `ResponseBodyHash`:响应体哈希值(必填)
|
||||||
|
- `OpHash`:操作整体哈希值(自动计算)
|
||||||
|
|
||||||
|
**创建方式**:
|
||||||
|
```go
|
||||||
|
op, err := model.NewFullOperation(
|
||||||
|
model.OpSourceDOIP, // 操作来源
|
||||||
|
model.OpTypeCreate, // 操作类型
|
||||||
|
dataID, // 数据标识
|
||||||
|
"user123", // 操作者
|
||||||
|
[]byte(`{"foo":"bar"}`), // 请求体(支持 string 或 []byte)
|
||||||
|
[]byte(`{"status":"ok"}`), // 响应体(支持 string 或 []byte)
|
||||||
|
model.SHA256, // 哈希算法
|
||||||
|
time.Now(), // 操作时间戳
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**发布方式**:
|
||||||
|
```go
|
||||||
|
client.OperationPublish(op) // 发布到 OperationTopic
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Record(简单记录)
|
||||||
|
|
||||||
|
`Record` 用于记录简单的事件或日志,轻量级设计,适合日志和事件追踪场景。
|
||||||
|
|
||||||
|
**适用场景**:
|
||||||
|
- 记录简单的日志信息
|
||||||
|
- 记录系统中的事件(如用户登录、配置变更等)
|
||||||
|
- 不需要完整请求/响应信息的轻量级记录场景
|
||||||
|
|
||||||
|
**核心字段**:
|
||||||
|
- `ID`:记录唯一标识符(自动生成 UUID v7)
|
||||||
|
- `DoPrefix`:节点前缀(可选)
|
||||||
|
- `Timestamp`:操作时间(可选,默认当前时间)
|
||||||
|
- `Operator`:用户标识(可选)
|
||||||
|
- `Extra`:额外数据(可选,`[]byte` 类型)
|
||||||
|
- `RCType`:记录类型(可选,如 `"log"`、`"event"` 等)
|
||||||
|
- `Algorithm`:哈希算法类型(默认 `Sha256Simd`)
|
||||||
|
- `RCHash`:记录哈希值(自动计算)
|
||||||
|
|
||||||
|
**创建方式**:
|
||||||
|
```go
|
||||||
|
// 方式一:完整创建
|
||||||
|
record, err := model.NewFullRecord(
|
||||||
|
"10.1000", // DoPrefix
|
||||||
|
time.Now(), // 时间戳
|
||||||
|
"operator123", // 操作者
|
||||||
|
[]byte("extra data"), // 额外数据
|
||||||
|
"log", // 记录类型
|
||||||
|
model.BLAKE3, // 哈希算法
|
||||||
|
)
|
||||||
|
|
||||||
|
// 方式二:链式调用创建
|
||||||
|
record, _ := model.NewRecord(model.SHA256)
|
||||||
|
record.WithDoPrefix("10.1000").
|
||||||
|
WithTimestamp(time.Now()).
|
||||||
|
WithOperator("operator123").
|
||||||
|
WithExtra([]byte("extra data")).
|
||||||
|
WithRCType("log")
|
||||||
|
```
|
||||||
|
|
||||||
|
**发布方式**:
|
||||||
|
```go
|
||||||
|
client.RecordPublish(record) // 发布到 RecordTopic
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 两种模型的对比
|
||||||
|
|
||||||
|
| 特性 | Operation | Record |
|
||||||
|
|------|-----------|--------|
|
||||||
|
| **用途** | 完整业务操作记录 | 简单事件/日志记录 |
|
||||||
|
| **请求/响应** | ✅ 包含请求体和响应体哈希 | ❌ 不包含 |
|
||||||
|
| **取证验证** | ✅ 完整取证验证流程 | ✅ 哈希验证 |
|
||||||
|
| **数据标识** | ✅ 完整的 DataID(Prefix/Repository/Doid) | ✅ 可选的 DoPrefix |
|
||||||
|
| **字段复杂度** | 较高(8+ 字段) | 较低(7 字段) |
|
||||||
|
| **Topic** | `persistent://public/default/operation` | `persistent://public/default/record` |
|
||||||
|
| **适用场景** | 审计、完整操作追踪 | 日志、事件追踪 |
|
||||||
|
|
||||||
|
### HashType(哈希算法)
|
||||||
|
|
||||||
|
两种模型都支持以下 18 种哈希算法:
|
||||||
|
- **MD5 系列**:`MD5`、`MD4`
|
||||||
|
- **SHA 系列**:`SHA1`、`SHA224`、`SHA256`、`SHA384`、`SHA512`、`SHA512/224`、`SHA512/256`、`SHA256-SIMD`
|
||||||
|
- **SHA3 系列**:`SHA3-224`、`SHA3-256`、`SHA3-384`、`SHA3-512`
|
||||||
|
- **BLAKE 系列**:`BLAKE3`、`BLAKE2B`、`BLAKE2S`
|
||||||
|
- **其他**:`RIPEMD160`
|
||||||
|
|
||||||
|
默认算法:`Sha256Simd`
|
||||||
|
|
||||||
|
### 组件说明
|
||||||
|
- **Publisher**
|
||||||
|
负责将 `Operation` 或 `Record` 序列化并发布到对应的 Topic:
|
||||||
|
- `Operation` → `persistent://public/default/operation`
|
||||||
|
- `Record` → `persistent://public/default/record`
|
||||||
|
|
||||||
|
- **Subscriber**
|
||||||
|
负责从 Topic 中订阅报文并进行 ack/nack 处理(一般无需直接使用)。可以订阅 `OperationTopic` 或 `RecordTopic`。
|
||||||
|
|
||||||
|
- **HighClient**
|
||||||
|
高层封装的发布客户端,方便业务代码发送 `Operation` 和 `Record` 消息。
|
||||||
|
|
||||||
|
- **QueryClient**
|
||||||
|
基于 gRPC 的统一查询客户端,提供:
|
||||||
|
- **Operation 操作查询**:列表查询和取证验证
|
||||||
|
- **Record 记录查询**:列表查询和验证
|
||||||
|
- **单一连接池**:两种服务共享同一组 gRPC 连接,支持多服务器负载均衡
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎯 使用场景
|
||||||
|
|
||||||
|
### 发布场景
|
||||||
|
|
||||||
|
#### Operation 发布场景
|
||||||
|
- **业务操作记录**:记录 DOIP/IRP 协议的完整操作(Create、Update、Delete 等)
|
||||||
|
- **审计追踪**:需要完整记录请求和响应的审计场景
|
||||||
|
- **取证验证**:需要支持完整取证验证的操作记录
|
||||||
|
|
||||||
|
#### Record 发布场景
|
||||||
|
- **日志记录**:记录简单的日志信息
|
||||||
|
- **事件追踪**:记录系统中的事件(如用户登录、配置变更等)
|
||||||
|
- **轻量级记录**:不需要完整请求/响应信息的场景
|
||||||
|
|
||||||
|
**发布方式**:
|
||||||
|
- **直接发布**:使用 Pulsar Publisher(SDK 已提供)发送到对应的 Pulsar 主题
|
||||||
|
- **事务性发布**:使用 Watermill Forwarder 将消息持久化到 SQL 数据库,保证消息的事务性和可靠性
|
||||||
|
|
||||||
|
### 查询场景
|
||||||
|
|
||||||
|
#### Operation 查询场景
|
||||||
|
- **操作列表查询**:查询历史操作记录列表(支持分页、按来源/类型/前缀/仓库过滤)
|
||||||
|
- **取证验证**:对特定操作执行完整的取证验证(流式返回进度)
|
||||||
|
|
||||||
|
#### Record 查询场景
|
||||||
|
- **记录列表查询**:查询历史记录列表(支持分页、按前缀和类型过滤)
|
||||||
|
- **记录验证**:对特定记录执行哈希验证(流式返回进度)
|
||||||
|
|
||||||
|
**统一客户端**:`QueryClient` 使用单一连接池同时支持两种服务,共享 gRPC 连接资源
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📝 快速开始
|
||||||
|
|
||||||
|
### 1. HighClient 使用(消息发布)
|
||||||
|
|
||||||
|
#### 1.1 创建 Logger
|
||||||
|
|
||||||
|
SDK 使用 [logr](https://github.com/go-logr/logr) 作为日志接口。你需要先创建一个 logr.Logger 实例,然后通过 `logger.NewLogger()` 包装成 SDK 的 Logger 接口。
|
||||||
|
|
||||||
|
##### 方式一:使用默认的 discard logger(适用于测试)
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 使用 discard logger(不输出任何日志)
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
```
|
||||||
|
|
||||||
|
##### 方式二:使用 zap(推荐生产环境)
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
"github.com/go-logr/zap"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 创建 zap logger
|
||||||
|
zapLogger, _ := zap.NewProduction()
|
||||||
|
// 转换为 logr.Logger
|
||||||
|
logrLogger := zapr.NewLogger(zapLogger)
|
||||||
|
// 包装成 SDK 的 Logger
|
||||||
|
myLogger := logger.NewLogger(logrLogger)
|
||||||
|
```
|
||||||
|
|
||||||
|
##### 方式三:使用其他 logr 实现
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
// 可以使用任何实现了 logr.LogSink 的实现
|
||||||
|
// 例如:github.com/go-logr/logr/slogr(基于 slog)
|
||||||
|
// github.com/go-logr/zap(基于 zap)
|
||||||
|
// github.com/go-logr/logrusr(基于 logrus)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 假设你有一个 logr.Logger 实例
|
||||||
|
var logrLogger logr.Logger
|
||||||
|
myLogger := logger.NewLogger(logrLogger)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1.2 创建 Publisher
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 创建 Logger(使用 discard 作为示例)
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// 创建 Pulsar Publisher
|
||||||
|
pub, err := adapter.NewPublisher(
|
||||||
|
adapter.PublisherConfig{
|
||||||
|
URL: "pulsar://localhost:6650",
|
||||||
|
},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer pub.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1.3 使用 HighClient 发送 Operation
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/highclient"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/model"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 准备SM2密钥(十六进制字符串格式)
|
||||||
|
privateKeyHex := []byte("私钥D的十六进制字符串,例如:abc123...")
|
||||||
|
publicKeyHex := []byte("04 + x坐标(32字节) + y坐标(32字节)的十六进制字符串")
|
||||||
|
|
||||||
|
// 创建Envelope配置
|
||||||
|
envelopeConfig := model.DefaultEnvelopeConfig(privateKeyHex, publicKeyHex)
|
||||||
|
|
||||||
|
// 创建高层客户端(使用Envelope序列化方式)
|
||||||
|
client := highclient.NewClient(pub, myLogger, envelopeConfig)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
// 构造 DataID
|
||||||
|
dataID := model.DataID{
|
||||||
|
DoPrefix: "10.1000",
|
||||||
|
DoRepository: "my-repo",
|
||||||
|
Doid: "10.1000/my-repo/object123",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造完整的 Operation
|
||||||
|
op, err := model.NewFullOperation(
|
||||||
|
model.OpSourceDOIP, // 操作来源:DOIP 或 IRP
|
||||||
|
model.OpTypeCreate, // 操作类型:Create, Update, Delete 等
|
||||||
|
dataID, // 数据标识
|
||||||
|
"user123", // 操作者
|
||||||
|
[]byte(`{"foo":"bar"}`), // 请求体
|
||||||
|
[]byte(`{"status":"ok"}`), // 响应体
|
||||||
|
model.Sha256Simd, // 哈希算法
|
||||||
|
time.Now(), // 操作时间
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送 Operation
|
||||||
|
if err := client.OperationPublish(op); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1.4 使用 HighClient 发送 Record
|
||||||
|
```go
|
||||||
|
// 构造 Record
|
||||||
|
record, err := model.NewFullRecord(
|
||||||
|
"10.1000", // DoPrefix
|
||||||
|
time.Now(), // 时间戳
|
||||||
|
"operator123", // 操作者
|
||||||
|
[]byte("extra data"), // 额外数据
|
||||||
|
"log", // 记录类型
|
||||||
|
model.BLAKE3, // 哈希算法
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送 Record
|
||||||
|
if err := client.RecordPublish(record); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1.5 获取底层 Publisher
|
||||||
|
```go
|
||||||
|
// 如果需要直接访问 Watermill Publisher
|
||||||
|
lowPublisher := client.GetLow()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. QueryClient 使用(统一查询客户端)
|
||||||
|
|
||||||
|
`QueryClient` 是统一的查询客户端,同时支持 **Operation(操作)** 和 **Record(记录)** 两种服务的查询和验证。使用单一连接池,两种服务共享同一组 gRPC 连接。
|
||||||
|
|
||||||
|
#### 2.1 创建 QueryClient
|
||||||
|
|
||||||
|
##### 单服务器模式
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/queryclient"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 创建 Logger
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// 创建统一查询客户端(单服务器)
|
||||||
|
queryClient, err := queryclient.NewClient(
|
||||||
|
queryclient.ClientConfig{
|
||||||
|
ServerAddr: "localhost:50051",
|
||||||
|
},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer queryClient.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
##### 多服务器负载均衡模式
|
||||||
|
```go
|
||||||
|
// 创建查询客户端(多服务器,自动轮询负载均衡)
|
||||||
|
queryClient, err := queryclient.NewClient(
|
||||||
|
queryclient.ClientConfig{
|
||||||
|
ServerAddrs: []string{
|
||||||
|
"server1:50051",
|
||||||
|
"server2:50051",
|
||||||
|
"server3:50051",
|
||||||
|
},
|
||||||
|
// DialOptions: []grpc.DialOption{...}, // 可选:自定义 gRPC 连接选项
|
||||||
|
},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer queryClient.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.2 查询操作列表
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 构造查询请求
|
||||||
|
req := queryclient.ListOperationsRequest{
|
||||||
|
PageSize: 100, // 每页数量
|
||||||
|
PreTime: time.Now().Add(-24 * time.Hour), // 游标分页(可选)
|
||||||
|
|
||||||
|
// 可选过滤条件
|
||||||
|
OpSource: model.OpSourceDOIP, // 按操作来源过滤
|
||||||
|
OpType: model.OpTypeCreate, // 按操作类型过滤
|
||||||
|
DoPrefix: "10.1000", // 按数据前缀过滤
|
||||||
|
DoRepository: "my-repo", // 按仓库过滤
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
resp, err := queryClient.ListOperations(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理结果
|
||||||
|
fmt.Printf("Total count: %d\n", resp.Count)
|
||||||
|
for _, op := range resp.Data {
|
||||||
|
fmt.Printf("Operation ID: %s, Type: %s, Time: %s\n",
|
||||||
|
op.Meta.OpID, op.Meta.OpType, op.Meta.Timestamp)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.3 取证验证(流式)
|
||||||
|
```go
|
||||||
|
// 构造验证请求
|
||||||
|
validationReq := queryclient.ValidationRequest{
|
||||||
|
Time: time.Now().Add(-1 * time.Hour),
|
||||||
|
OpID: "operation-id-123",
|
||||||
|
OpType: "Create",
|
||||||
|
DoRepository: "my-repo",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步验证(流式接收进度)
|
||||||
|
resultChan, err := queryClient.ValidateOperation(ctx, validationReq)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理流式结果
|
||||||
|
for result := range resultChan {
|
||||||
|
if result.IsProcessing() {
|
||||||
|
fmt.Printf("Progress: %s - %s\n", result.Progress, result.Msg)
|
||||||
|
} else if result.IsCompleted() {
|
||||||
|
fmt.Println("Validation completed successfully!")
|
||||||
|
if result.Data != nil {
|
||||||
|
fmt.Printf("Operation: %+v\n", result.Data)
|
||||||
|
}
|
||||||
|
} else if result.IsFailed() {
|
||||||
|
fmt.Printf("Validation failed: %s\n", result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.4 取证验证(同步)
|
||||||
|
```go
|
||||||
|
// 同步验证(阻塞直到完成)
|
||||||
|
finalResult, err := queryClient.ValidateOperationSync(
|
||||||
|
ctx,
|
||||||
|
validationReq,
|
||||||
|
func(progress *model.ValidationResult) {
|
||||||
|
// 可选的进度回调
|
||||||
|
fmt.Printf("Progress: %s\n", progress.Progress)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalResult.IsCompleted() {
|
||||||
|
fmt.Println("Validation successful!")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Validation failed: %s\n", finalResult.Msg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.5 查询记录列表(Record)
|
||||||
|
```go
|
||||||
|
// 构造记录查询请求
|
||||||
|
recordReq := queryclient.ListRecordsRequest{
|
||||||
|
PageSize: 50, // 每页数量
|
||||||
|
PreTime: time.Now().Add(-24 * time.Hour), // 游标分页(可选)
|
||||||
|
|
||||||
|
// 可选过滤条件
|
||||||
|
DoPrefix: "10.1000", // 按数据前缀过滤
|
||||||
|
RCType: "log", // 按记录类型过滤
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
recordResp, err := queryClient.ListRecords(ctx, recordReq)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理结果
|
||||||
|
fmt.Printf("Total records: %d\n", recordResp.Count)
|
||||||
|
for _, rec := range recordResp.Data {
|
||||||
|
fmt.Printf("Record ID: %s, Type: %s, Hash: %s\n",
|
||||||
|
rec.ID, rec.RCType, rec.RCHash)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.6 记录验证(流式)
|
||||||
|
```go
|
||||||
|
// 构造记录验证请求
|
||||||
|
recordValidationReq := queryclient.RecordValidationRequest{
|
||||||
|
Timestamp: time.Now().Add(-1 * time.Hour),
|
||||||
|
RecordID: "record-id-123",
|
||||||
|
DoPrefix: "10.1000",
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步验证(流式接收进度)
|
||||||
|
recordResultChan, err := queryClient.ValidateRecord(ctx, recordValidationReq)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理流式结果
|
||||||
|
for result := range recordResultChan {
|
||||||
|
if result.IsProcessing() {
|
||||||
|
fmt.Printf("Progress: %s - %s\n", result.Progress, result.Msg)
|
||||||
|
} else if result.IsCompleted() {
|
||||||
|
fmt.Println("Record validation completed!")
|
||||||
|
if result.Data != nil {
|
||||||
|
fmt.Printf("Record: %+v\n", result.Data)
|
||||||
|
}
|
||||||
|
} else if result.IsFailed() {
|
||||||
|
fmt.Printf("Record validation failed: %s\n", result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.7 记录验证(同步)
|
||||||
|
```go
|
||||||
|
// 同步验证(阻塞直到完成)
|
||||||
|
finalRecordResult, err := queryClient.ValidateRecordSync(
|
||||||
|
ctx,
|
||||||
|
recordValidationReq,
|
||||||
|
func(progress *model.RecordValidationResult) {
|
||||||
|
// 可选的进度回调
|
||||||
|
fmt.Printf("Progress: %s\n", progress.Progress)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalRecordResult.IsCompleted() {
|
||||||
|
fmt.Println("Record validation successful!")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Record validation failed: %s\n", finalRecordResult.Msg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.8 获取底层 gRPC 客户端
|
||||||
|
```go
|
||||||
|
// 高级用户可以直接访问 gRPC 客户端进行自定义操作
|
||||||
|
|
||||||
|
// 获取 Operation 服务客户端
|
||||||
|
opGrpcClient := queryClient.GetLowLevelOperationClient()
|
||||||
|
|
||||||
|
// 获取 Record 服务客户端
|
||||||
|
recGrpcClient := queryClient.GetLowLevelRecordClient()
|
||||||
|
|
||||||
|
// 注意:多服务器模式下,每次调用会返回轮询的下一个客户端
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Subscriber 使用(消息订阅)
|
||||||
|
|
||||||
|
> **注意**:通常业务代码不需要直接使用 Subscriber,除非需要原始的 Watermill 消息处理。
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/model"
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/bytedance/sonic"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 创建 Logger
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
sub, err := adapter.NewSubscriber(
|
||||||
|
adapter.SubscriberConfig{
|
||||||
|
URL: "pulsar://localhost:6650",
|
||||||
|
SubscriberType: pulsar.KeyShared, // 必须使用 KeyShared 模式
|
||||||
|
},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
// 订阅消息(context 必须携带 key 为 "subName" 的 value)
|
||||||
|
ctx := context.WithValue(context.Background(), "subName", "my-subscriber")
|
||||||
|
msgChan, err := sub.Subscribe(ctx, adapter.OperationTopic) // 或者 adapter.RecordTopic
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理消息
|
||||||
|
for msg := range msgChan {
|
||||||
|
var op model.Operation
|
||||||
|
if err := sonic.Unmarshal(msg.Payload, &op); err != nil {
|
||||||
|
myLogger.ErrorContext(ctx, "Invalid Operation message", "error", err)
|
||||||
|
msg.Nack()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理业务逻辑
|
||||||
|
myLogger.InfoContext(ctx, "Received Operation", "key", op.Key())
|
||||||
|
|
||||||
|
// 根据业务成功与否 ack / nack
|
||||||
|
msg.Ack()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Forwarder 事务性发布(SQL持久化)
|
||||||
|
|
||||||
|
使用 Watermill Forwarder 可以将消息先持久化到 SQL 数据库,然后异步发送到 Pulsar,保证消息的事务性和可靠性。
|
||||||
|
这在需要确保消息不丢失的场景下非常有用。
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"github.com/ThreeDotsLabs/watermill/components/forwarder"
|
||||||
|
"github.com/ThreeDotsLabs/watermill-sql/v3/pkg/sql"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/highclient"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 0. 创建 Logger
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// 1. 创建 SQL Publisher(用于持久化)
|
||||||
|
db, err := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlPublisher, err := watermillsql.NewPublisher(
|
||||||
|
db,
|
||||||
|
watermillsql.PublisherConfig{
|
||||||
|
SchemaAdapter: watermillsql.DefaultPostgreSQLSchema{},
|
||||||
|
},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 创建 Pulsar Publisher(实际发送)
|
||||||
|
pulsarPublisher, err := adapter.NewPublisher(
|
||||||
|
adapter.PublisherConfig{URL: "pulsar://localhost:6650"},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 创建 Forwarder(SQL -> Pulsar)
|
||||||
|
// 消息先写入 SQL,事务提交后异步转发到 Pulsar
|
||||||
|
fwd, err := forwarder.NewForwarder(sqlPublisher, pulsarPublisher)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 使用 Forwarder 创建客户端
|
||||||
|
// 发布的消息会先存储到 SQL,保证事务性
|
||||||
|
client := highclient.NewClient(fwd, myLogger)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
// 5. 在数据库事务中发布消息
|
||||||
|
tx, _ := db.Begin()
|
||||||
|
// ... 执行业务数据库操作 ...
|
||||||
|
|
||||||
|
// 发布 Operation(会在同一个事务中写入)
|
||||||
|
_ = client.OperationPublish(op)
|
||||||
|
|
||||||
|
// 提交事务(业务数据和消息同时提交)
|
||||||
|
tx.Commit()
|
||||||
|
```
|
||||||
|
|
||||||
|
> **优势**:
|
||||||
|
> - ✅ 消息与业务数据在同一事务中,保证强一致性
|
||||||
|
> - ✅ 即使 Pulsar 暂时不可用,消息也不会丢失
|
||||||
|
> - ✅ Forwarder 会自动重试发送失败的消息
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎨 完整示例
|
||||||
|
|
||||||
|
### 发布 + 查询 + 验证完整流程
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/highclient"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/queryclient"
|
||||||
|
"go.yandata.net/iod/iod/go-trustlog/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 0. 创建 Logger
|
||||||
|
myLogger := logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// 1. 创建并发送 Operation
|
||||||
|
pub, _ := adapter.NewPublisher(
|
||||||
|
adapter.PublisherConfig{URL: "pulsar://localhost:6650"},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
defer pub.Close()
|
||||||
|
|
||||||
|
// 准备SM2密钥
|
||||||
|
privateKeyHex := []byte("私钥D的十六进制字符串")
|
||||||
|
publicKeyHex := []byte("04 + x坐标 + y坐标的十六进制字符串")
|
||||||
|
envelopeConfig := model.DefaultEnvelopeConfig(privateKeyHex, publicKeyHex)
|
||||||
|
|
||||||
|
client := highclient.NewClient(pub, myLogger, envelopeConfig)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
dataID := model.DataID{
|
||||||
|
DoPrefix: "10.1000",
|
||||||
|
DoRepository: "test-repo",
|
||||||
|
Doid: "10.1000/test-repo/doc001",
|
||||||
|
}
|
||||||
|
|
||||||
|
op, _ := model.NewFullOperation(
|
||||||
|
model.OpSourceDOIP,
|
||||||
|
model.OpTypeCreate,
|
||||||
|
dataID,
|
||||||
|
"admin",
|
||||||
|
[]byte(`{"action":"create"}`),
|
||||||
|
[]byte(`{"status":"success"}`),
|
||||||
|
model.SHA256,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = client.OperationPublish(op)
|
||||||
|
fmt.Printf("Published operation: %s\n", op.Meta.OpID)
|
||||||
|
|
||||||
|
// 等待一段时间让消息被处理
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
// 2. 查询操作列表
|
||||||
|
queryClient, _ := queryclient.NewClient(
|
||||||
|
queryclient.ClientConfig{ServerAddr: "localhost:50051"},
|
||||||
|
myLogger,
|
||||||
|
)
|
||||||
|
defer queryClient.Close()
|
||||||
|
|
||||||
|
listResp, _ := queryClient.ListOperations(ctx, queryclient.ListOperationsRequest{
|
||||||
|
PageSize: 10,
|
||||||
|
DoRepository: "test-repo",
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Printf("Found %d operations\n", listResp.Count)
|
||||||
|
|
||||||
|
// 3. 执行取证验证
|
||||||
|
if len(listResp.Data) > 0 {
|
||||||
|
firstOp := listResp.Data[0]
|
||||||
|
|
||||||
|
validationReq := queryclient.ValidationRequest{
|
||||||
|
Time: firstOp.Meta.Timestamp,
|
||||||
|
OpID: firstOp.Meta.OpID,
|
||||||
|
OpType: string(firstOp.Meta.OpType),
|
||||||
|
DoRepository: firstOp.DataID.DoRepository,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := queryClient.ValidateOperationSync(ctx, validationReq, nil)
|
||||||
|
|
||||||
|
if result.IsCompleted() {
|
||||||
|
fmt.Println("✅ Validation passed!")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("❌ Validation failed: %s\n", result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📚 操作类型枚举
|
||||||
|
|
||||||
|
### DOIP 操作类型(7种)
|
||||||
|
```go
|
||||||
|
model.OpTypeHello // Hello 握手
|
||||||
|
model.OpTypeRetrieve // 检索资源
|
||||||
|
model.OpTypeCreate // 新建资源
|
||||||
|
model.OpTypeDelete // 删除资源
|
||||||
|
model.OpTypeUpdate // 更新资源
|
||||||
|
model.OpTypeSearch // 搜索资源
|
||||||
|
model.OpTypeListOperations // 列出可用操作
|
||||||
|
```
|
||||||
|
|
||||||
|
### IRP 操作类型(33种)
|
||||||
|
```go
|
||||||
|
// Handle 基础操作
|
||||||
|
model.OpTypeOCReserved, model.OpTypeOCResolution, model.OpTypeOCGetSiteInfo
|
||||||
|
model.OpTypeOCCreateHandle, model.OpTypeOCDeleteHandle, model.OpTypeOCAddValue
|
||||||
|
model.OpTypeOCRemoveValue, model.OpTypeOCModifyValue, model.OpTypeOCListHandle
|
||||||
|
model.OpTypeOCListNA
|
||||||
|
|
||||||
|
// DOID 操作
|
||||||
|
model.OpTypeOCResolutionDOID, model.OpTypeOCCreateDOID, model.OpTypeOCDeleteDOID
|
||||||
|
model.OpTypeOCUpdateDOID, model.OpTypeOCBatchCreateDOID, model.OpTypeOCResolutionDOIDRecursive
|
||||||
|
|
||||||
|
// 用户与仓库
|
||||||
|
model.OpTypeOCGetUsers, model.OpTypeOCGetRepos
|
||||||
|
|
||||||
|
// GRS/IRS 管理
|
||||||
|
model.OpTypeOCVerifyIRS, model.OpTypeOCResolveGRS, model.OpTypeOCCreateOrgGRS
|
||||||
|
model.OpTypeOCUpdateOrgGRS, model.OpTypeOCDeleteOrgGRS, model.OpTypeOCSyncOrgIRSParent
|
||||||
|
model.OpTypeOCUpdateOrgIRSParent, model.OpTypeOCDeleteOrgIRSParent
|
||||||
|
|
||||||
|
// 安全与会话
|
||||||
|
model.OpTypeOCChallengeResponse, model.OpTypeOCVerifyChallenge, model.OpTypeOCSessionSetup
|
||||||
|
model.OpTypeOCSessionTerminate, model.OpTypeOCSessionExchangeKey, model.OpTypeOCVerifyRouter
|
||||||
|
model.OpTypeOCQueryRouter
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ⚠️ 注意事项
|
||||||
|
|
||||||
|
1. **私有仓库配置**
|
||||||
|
必须先配置 Git SSH 映射和 GOPRIVATE 环境变量,否则无法正常安装 SDK。
|
||||||
|
|
||||||
|
2. **日志接口**
|
||||||
|
SDK 使用 [logr](https://github.com/go-logr/logr) 作为日志接口。你需要:
|
||||||
|
- 创建一个 `logr.Logger` 实例(可以使用 zap、logrus 等实现)
|
||||||
|
- 通过 `logger.NewLogger(logrLogger)` 包装成 SDK 的 Logger 接口
|
||||||
|
- 在生产环境建议使用 `zapr` 或 `logrusr` 等实现,测试环境可以使用 `logr.Discard()`
|
||||||
|
|
||||||
|
3. **HighClient 方法名**
|
||||||
|
- 发送 Operation 使用 `client.OperationPublish(op)`,参数为指针类型 `*model.Operation`
|
||||||
|
- 发送 Record 使用 `client.RecordPublish(record)`,参数为指针类型 `*model.Record`
|
||||||
|
|
||||||
|
4. **固定主题**
|
||||||
|
- Operation 主题:`persistent://public/default/operation`
|
||||||
|
- Record 主题:`persistent://public/default/record`
|
||||||
|
|
||||||
|
5. **KeyShared 消费模式**
|
||||||
|
由于 Trustlog 使用 Key Shared 消费模式,其他订阅者必须选择 KeyShared 并避免消费者重名。
|
||||||
|
|
||||||
|
6. **ack/nack 必须处理**
|
||||||
|
确保订阅方根据业务逻辑确认或拒绝消息。
|
||||||
|
|
||||||
|
7. **时间戳处理**
|
||||||
|
`NewFullOperation()` 接受 `time.Time` 类型的时间戳参数。
|
||||||
|
|
||||||
|
8. **统一连接池**
|
||||||
|
QueryClient 使用单一连接池同时支持 Operation 和 Record 两种服务,共享 gRPC 连接资源,提高资源利用率。
|
||||||
|
|
||||||
|
9. **负载均衡**
|
||||||
|
支持多服务器轮询负载均衡,自动分发请求到不同服务器,连接在两种服务间共享。
|
||||||
|
|
||||||
|
10. **流式验证**
|
||||||
|
取证验证(Operation 和 Record)都支持流式和同步两种模式,流式模式可实时获取进度。
|
||||||
|
|
||||||
|
11. **事务性发布**
|
||||||
|
使用 Watermill Forwarder 可以将消息持久化到 SQL,与业务数据在同一事务中提交,保证强一致性。
|
||||||
|
|
||||||
|
12. **Record 支持**
|
||||||
|
除了 Operation,SDK 现在也支持 Record 类型的发布、查询和验证,两种服务使用同一个 QueryClient。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔄 架构图
|
||||||
|
|
||||||
|
### 直接发布架构
|
||||||
|
```
|
||||||
|
[业务服务]
|
||||||
|
↓
|
||||||
|
[HighClient.Publish()]
|
||||||
|
↓
|
||||||
|
[Pulsar Publisher] --(Operation JSON)--> [Pulsar Topic]
|
||||||
|
↓
|
||||||
|
[Subscriber]
|
||||||
|
↓
|
||||||
|
[其他服务]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 事务性发布架构(使用 Forwarder)
|
||||||
|
```
|
||||||
|
[业务服务 + DB事务]
|
||||||
|
↓
|
||||||
|
[HighClient.Publish()]
|
||||||
|
↓
|
||||||
|
[SQL Publisher] --写入--> [PostgreSQL/MySQL]
|
||||||
|
↓ ↓
|
||||||
|
[Forwarder 后台轮询] |
|
||||||
|
↓ |
|
||||||
|
[读取未发送消息] <--------------┘
|
||||||
|
↓
|
||||||
|
[Pulsar Publisher] --(Operation JSON)--> [Pulsar Topic]
|
||||||
|
↓ ↓
|
||||||
|
[标记为已发送] [Subscriber]
|
||||||
|
↓
|
||||||
|
[其他服务]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 查询架构(统一连接池)
|
||||||
|
```
|
||||||
|
[业务服务]
|
||||||
|
↓
|
||||||
|
[QueryClient - 单一连接池]
|
||||||
|
├─ Operation 服务客户端 ─┐
|
||||||
|
└─ Record 服务客户端 ────┤
|
||||||
|
↓ (共享 gRPC 连接,轮询负载均衡)
|
||||||
|
[Server 1] ─┐
|
||||||
|
[Server 2] ─┼─ 多服务器
|
||||||
|
[Server 3] ─┘
|
||||||
|
↓
|
||||||
|
[存储层]
|
||||||
|
|
||||||
|
优势:
|
||||||
|
- 单一连接池,资源高效利用
|
||||||
|
- Operation 和 Record 服务共享连接
|
||||||
|
- 自动负载均衡,请求分发到不同服务器
|
||||||
|
- 减少连接数,降低服务器压力
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
205
api/adapter/TCP_QUICK_START.md
Normal file
205
api/adapter/TCP_QUICK_START.md
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# TCP 适配器快速开始指南
|
||||||
|
|
||||||
|
## 简介
|
||||||
|
|
||||||
|
TCP 适配器提供了一个无需 Pulsar 的 Watermill 消息发布/订阅实现,适用于内网直连场景。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1. 启动消费端(Subscriber)
|
||||||
|
|
||||||
|
消费端作为 TCP 服务器,监听指定端口。
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 使用 NopLogger 或自定义 logger
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建 Subscriber
|
||||||
|
config := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:9090",
|
||||||
|
}
|
||||||
|
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(config, log)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
// 订阅 topic
|
||||||
|
messages, err := subscriber.Subscribe(context.Background(), "my-topic")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理消息
|
||||||
|
for msg := range messages {
|
||||||
|
log.Println("收到消息:", string(msg.Payload))
|
||||||
|
msg.Ack() // 确认消息
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 启动生产端(Publisher)
|
||||||
|
|
||||||
|
生产端作为 TCP 客户端,连接到消费端。
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:9090",
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
AckTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
msg := message.NewMessage("msg-001", []byte("Hello, World!"))
|
||||||
|
|
||||||
|
err = publisher.Publish("my-topic", msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("消息发送成功")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 特性演示
|
||||||
|
|
||||||
|
### 并发发送多条消息
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 准备 10 条消息
|
||||||
|
messages := make([]*message.Message, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
payload := []byte(fmt.Sprintf("Message #%d", i))
|
||||||
|
messages[i] = message.NewMessage(fmt.Sprintf("msg-%d", i), payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 并发发送,Publisher 会等待所有 ACK
|
||||||
|
err := publisher.Publish("my-topic", messages...)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("所有消息发送成功")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 错误处理和 NACK
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 在消费端
|
||||||
|
for msg := range messages {
|
||||||
|
// 处理消息
|
||||||
|
if err := processMessage(msg); err != nil {
|
||||||
|
log.Println("处理失败:", err)
|
||||||
|
msg.Nack() // 拒绝消息
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg.Ack() // 确认消息
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置参数
|
||||||
|
|
||||||
|
### TCPPublisherConfig
|
||||||
|
|
||||||
|
```go
|
||||||
|
type TCPPublisherConfig struct {
|
||||||
|
ServerAddr string // 必填: TCP 服务器地址,如 "127.0.0.1:9090"
|
||||||
|
ConnectTimeout time.Duration // 连接超时,默认 10s
|
||||||
|
AckTimeout time.Duration // ACK 超时,默认 30s
|
||||||
|
MaxRetries int // 最大重试次数,默认 3
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### TCPSubscriberConfig
|
||||||
|
|
||||||
|
```go
|
||||||
|
type TCPSubscriberConfig struct {
|
||||||
|
ListenAddr string // 必填: 监听地址,如 "127.0.0.1:9090"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行完整示例
|
||||||
|
cd trustlog-sdk/examples
|
||||||
|
go run tcp_example.go
|
||||||
|
```
|
||||||
|
|
||||||
|
## 性能特点
|
||||||
|
|
||||||
|
- ✅ **低延迟**: 直接 TCP 连接,无中间件开销
|
||||||
|
- ✅ **高并发**: 支持并发发送多条消息
|
||||||
|
- ✅ **可靠性**: 每条消息都需要 ACK 确认
|
||||||
|
- ⚠️ **无持久化**: 消息仅在内存中传递
|
||||||
|
|
||||||
|
## 适用场景
|
||||||
|
|
||||||
|
✅ **适合:**
|
||||||
|
- 内网服务间直接通信
|
||||||
|
- 开发和测试环境
|
||||||
|
- 无需消息持久化的场景
|
||||||
|
- 低延迟要求的场景
|
||||||
|
|
||||||
|
❌ **不适合:**
|
||||||
|
- 需要消息持久化
|
||||||
|
- 需要高可用和故障恢复
|
||||||
|
- 公网通信(需要加密)
|
||||||
|
- 需要复杂的路由和负载均衡
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### Q: 如何处理连接断开?
|
||||||
|
|
||||||
|
A: 当前版本连接断开后需要重新创建 Publisher。未来版本将支持自动重连。
|
||||||
|
|
||||||
|
### Q: 消息会丢失吗?
|
||||||
|
|
||||||
|
A: TCP 适配器不提供持久化,连接断开或服务重启会导致未确认的消息丢失。
|
||||||
|
|
||||||
|
### Q: 如何实现多个消费者?
|
||||||
|
|
||||||
|
A: 当前版本将消息发送到第一个订阅者。如需负载均衡,需要在应用层实现。
|
||||||
|
|
||||||
|
### Q: 支持 TLS 加密吗?
|
||||||
|
|
||||||
|
A: 当前版本不支持 TLS。未来版本将添加 TLS/mTLS 支持。
|
||||||
|
|
||||||
|
## 下一步
|
||||||
|
|
||||||
|
- 查看 [完整文档](TCP_ADAPTER_README.md)
|
||||||
|
- 运行 [测试用例](tcp_integration_test.go)
|
||||||
|
- 查看 [示例代码](../../examples/tcp_example.go)
|
||||||
|
|
||||||
608
api/adapter/mocks/pulsar_mock.go
Normal file
608
api/adapter/mocks/pulsar_mock.go
Normal file
@@ -0,0 +1,608 @@
|
|||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPulsarClient is a mock implementation of pulsar.Client.
|
||||||
|
type MockPulsarClient struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
producers map[string]*MockProducer
|
||||||
|
consumers map[string]*MockConsumer
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockPulsarClient creates a new mock Pulsar client.
|
||||||
|
func NewMockPulsarClient() *MockPulsarClient {
|
||||||
|
return &MockPulsarClient{
|
||||||
|
producers: make(map[string]*MockProducer),
|
||||||
|
consumers: make(map[string]*MockConsumer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProducer creates a mock producer.
|
||||||
|
func (m *MockPulsarClient) CreateProducer(options pulsar.ProducerOptions) (pulsar.Producer, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return nil, errors.New("client is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.producers == nil {
|
||||||
|
m.producers = make(map[string]*MockProducer)
|
||||||
|
}
|
||||||
|
|
||||||
|
producer := NewMockProducer(options.Topic)
|
||||||
|
m.producers[options.Topic] = producer
|
||||||
|
return producer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe creates a mock consumer.
|
||||||
|
func (m *MockPulsarClient) Subscribe(options pulsar.ConsumerOptions) (pulsar.Consumer, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return nil, errors.New("client is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.consumers == nil {
|
||||||
|
m.consumers = make(map[string]*MockConsumer)
|
||||||
|
}
|
||||||
|
|
||||||
|
consumer := NewMockConsumer(options.Topic, options.Name)
|
||||||
|
m.consumers[options.Name] = consumer
|
||||||
|
return consumer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateReader is not implemented.
|
||||||
|
func (m *MockPulsarClient) CreateReader(options pulsar.ReaderOptions) (pulsar.Reader, error) {
|
||||||
|
return nil, errors.New("CreateReader not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTableView is not implemented.
|
||||||
|
func (m *MockPulsarClient) CreateTableView(options pulsar.TableViewOptions) (pulsar.TableView, error) {
|
||||||
|
return nil, errors.New("CreateTableView not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTransaction creates a new transaction.
|
||||||
|
func (m *MockPulsarClient) NewTransaction(timeout time.Duration) (pulsar.Transaction, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TopicPartitions returns the partitions for a topic.
|
||||||
|
func (m *MockPulsarClient) TopicPartitions(topic string) ([]string, error) {
|
||||||
|
return []string{topic}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the mock client.
|
||||||
|
func (m *MockPulsarClient) Close() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.closed = true
|
||||||
|
for _, producer := range m.producers {
|
||||||
|
producer.Close()
|
||||||
|
}
|
||||||
|
for _, consumer := range m.consumers {
|
||||||
|
consumer.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProducer returns a producer by topic (for testing).
|
||||||
|
func (m *MockPulsarClient) GetProducer(topic string) *MockProducer {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.producers[topic]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConsumer returns a consumer by name (for testing).
|
||||||
|
func (m *MockPulsarClient) GetConsumer(name string) *MockConsumer {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.consumers[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockProducer is a mock implementation of pulsar.Producer.
|
||||||
|
type MockProducer struct {
|
||||||
|
topic string
|
||||||
|
name string
|
||||||
|
messages []*pulsar.ProducerMessage
|
||||||
|
mu sync.RWMutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockProducer creates a new mock producer.
|
||||||
|
func NewMockProducer(topic string) *MockProducer {
|
||||||
|
return &MockProducer{
|
||||||
|
topic: topic,
|
||||||
|
name: "mock-producer",
|
||||||
|
messages: make([]*pulsar.ProducerMessage, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topic returns the topic name.
|
||||||
|
func (m *MockProducer) Topic() string {
|
||||||
|
return m.topic
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the producer name.
|
||||||
|
func (m *MockProducer) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message.
|
||||||
|
func (m *MockProducer) Send(ctx context.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return nil, errors.New("producer is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.messages = append(m.messages, msg)
|
||||||
|
return &MockMessageID{id: len(m.messages)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendAsync sends a message asynchronously.
|
||||||
|
func (m *MockProducer) SendAsync(
|
||||||
|
ctx context.Context,
|
||||||
|
msg *pulsar.ProducerMessage,
|
||||||
|
callback func(pulsar.MessageID, *pulsar.ProducerMessage, error),
|
||||||
|
) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
callback(nil, msg, errors.New("producer is closed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.messages = append(m.messages, msg)
|
||||||
|
callback(&MockMessageID{id: len(m.messages)}, msg, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastSequenceID returns the last sequence ID.
|
||||||
|
func (m *MockProducer) LastSequenceID() int64 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush flushes pending messages.
|
||||||
|
func (m *MockProducer) Flush() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushWithCtx flushes pending messages with context.
|
||||||
|
func (m *MockProducer) FlushWithCtx(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the producer.
|
||||||
|
func (m *MockProducer) Close() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages returns all sent messages (for testing).
|
||||||
|
func (m *MockProducer) GetMessages() []*pulsar.ProducerMessage {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]*pulsar.ProducerMessage, len(m.messages))
|
||||||
|
copy(result, m.messages)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockConsumer is a mock implementation of pulsar.Consumer.
|
||||||
|
type MockConsumer struct {
|
||||||
|
topic string
|
||||||
|
name string
|
||||||
|
messageChan chan pulsar.ConsumerMessage
|
||||||
|
mu sync.RWMutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultMessageChannelSize 定义消息通道的默认缓冲大小.
|
||||||
|
defaultMessageChannelSize = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewMockConsumer creates a new mock consumer.
|
||||||
|
func NewMockConsumer(topic, name string) *MockConsumer {
|
||||||
|
return &MockConsumer{
|
||||||
|
topic: topic,
|
||||||
|
name: name,
|
||||||
|
messageChan: make(chan pulsar.ConsumerMessage, defaultMessageChannelSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscription returns the subscription name.
|
||||||
|
func (m *MockConsumer) Subscription() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topic returns the topic name.
|
||||||
|
func (m *MockConsumer) Topic() string {
|
||||||
|
return m.topic
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chan returns the message channel.
|
||||||
|
func (m *MockConsumer) Chan() <-chan pulsar.ConsumerMessage {
|
||||||
|
return m.messageChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ack acknowledges a message.
|
||||||
|
func (m *MockConsumer) Ack(msg pulsar.Message) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nack negatively acknowledges a message.
|
||||||
|
func (m *MockConsumer) Nack(msg pulsar.Message) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
// Mock implementation: 实际不做任何操作
|
||||||
|
_ = msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// NackID negatively acknowledges a message by ID.
|
||||||
|
func (m *MockConsumer) NackID(msgID pulsar.MessageID) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
// Mock implementation: 实际不做任何操作
|
||||||
|
_ = msgID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe unsubscribes the consumer.
|
||||||
|
func (m *MockConsumer) Unsubscribe() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsubscribeForce forcefully unsubscribes the consumer.
|
||||||
|
func (m *MockConsumer) UnsubscribeForce() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a single message.
|
||||||
|
func (m *MockConsumer) Receive(ctx context.Context) (pulsar.Message, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return nil, errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-m.messageChan:
|
||||||
|
return msg.Message, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckCumulative acknowledges all messages up to and including the provided message.
|
||||||
|
func (m *MockConsumer) AckCumulative(msg pulsar.Message) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckID acknowledges a message by ID.
|
||||||
|
func (m *MockConsumer) AckID(msgID pulsar.MessageID) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckIDCumulative acknowledges all messages up to and including the provided message ID.
|
||||||
|
func (m *MockConsumer) AckIDCumulative(msgID pulsar.MessageID) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckIDList acknowledges a list of message IDs.
|
||||||
|
func (m *MockConsumer) AckIDList(msgIDs []pulsar.MessageID) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckWithTxn acknowledges a message with transaction.
|
||||||
|
func (m *MockConsumer) AckWithTxn(msg pulsar.Message, txn pulsar.Transaction) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastMessageIDs returns the last message IDs.
|
||||||
|
func (m *MockConsumer) GetLastMessageIDs() ([]pulsar.TopicMessageID, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return nil, errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return []pulsar.TopicMessageID{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReconsumeLater reconsumes a message later with delay.
|
||||||
|
func (m *MockConsumer) ReconsumeLater(msg pulsar.Message, delay time.Duration) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
// Mock implementation: 实际不做任何操作
|
||||||
|
_, _ = msg, delay
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReconsumeLaterWithCustomProperties reconsumes a message later with custom properties.
|
||||||
|
func (m *MockConsumer) ReconsumeLaterWithCustomProperties(
|
||||||
|
msg pulsar.Message,
|
||||||
|
customProperties map[string]string,
|
||||||
|
delay time.Duration,
|
||||||
|
) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
// Mock implementation: 实际不做任何操作
|
||||||
|
_, _, _ = msg, customProperties, delay
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seek seeks to a message ID.
|
||||||
|
func (m *MockConsumer) Seek(msgID pulsar.MessageID) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SeekByTime seeks to a time.
|
||||||
|
func (m *MockConsumer) SeekByTime(t time.Time) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the consumer name.
|
||||||
|
func (m *MockConsumer) Name() string {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the consumer.
|
||||||
|
func (m *MockConsumer) Close() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.closed = true
|
||||||
|
close(m.messageChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a message to the consumer channel (for testing).
|
||||||
|
func (m *MockConsumer) SendMessage(msg pulsar.ConsumerMessage) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if m.closed {
|
||||||
|
return errors.New("consumer is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case m.messageChan <- msg:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("channel full")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockMessageID is a mock implementation of pulsar.MessageID.
|
||||||
|
type MockMessageID struct {
|
||||||
|
id int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize serializes the message ID.
|
||||||
|
func (m *MockMessageID) Serialize() []byte {
|
||||||
|
return []byte{byte(m.id)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchIdx returns the batch index.
|
||||||
|
func (m *MockMessageID) BatchIdx() int32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the batch size.
|
||||||
|
func (m *MockMessageID) BatchSize() int32 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the message ID.
|
||||||
|
func (m *MockMessageID) String() string {
|
||||||
|
return fmt.Sprintf("mock-message-id-%d", m.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EntryID returns the entry ID.
|
||||||
|
func (m *MockMessageID) EntryID() int64 {
|
||||||
|
return int64(m.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LedgerID returns the ledger ID.
|
||||||
|
func (m *MockMessageID) LedgerID() int64 {
|
||||||
|
return int64(m.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PartitionIdx returns the partition index.
|
||||||
|
func (m *MockMessageID) PartitionIdx() int32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockMessage is a mock implementation of pulsar.Message.
|
||||||
|
type MockMessage struct {
|
||||||
|
key string
|
||||||
|
payload []byte
|
||||||
|
id pulsar.MessageID
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockMessage creates a new mock message.
|
||||||
|
func NewMockMessage(key string, payload []byte) *MockMessage {
|
||||||
|
return &MockMessage{
|
||||||
|
key: key,
|
||||||
|
payload: payload,
|
||||||
|
id: &MockMessageID{id: 1},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topic returns the topic name.
|
||||||
|
func (m *MockMessage) Topic() string {
|
||||||
|
return "mock-topic"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Properties returns message properties.
|
||||||
|
func (m *MockMessage) Properties() map[string]string {
|
||||||
|
return make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Payload returns the message payload.
|
||||||
|
func (m *MockMessage) Payload() []byte {
|
||||||
|
return m.payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the message ID.
|
||||||
|
func (m *MockMessage) ID() pulsar.MessageID {
|
||||||
|
return m.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishTime returns the publish time.
|
||||||
|
func (m *MockMessage) PublishTime() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventTime returns the event time.
|
||||||
|
func (m *MockMessage) EventTime() time.Time {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key returns the message key.
|
||||||
|
func (m *MockMessage) Key() string {
|
||||||
|
return m.key
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderingKey returns the ordering key.
|
||||||
|
func (m *MockMessage) OrderingKey() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedeliveryCount returns the redelivery count.
|
||||||
|
func (m *MockMessage) RedeliveryCount() uint32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReplicated returns whether the message is replicated.
|
||||||
|
func (m *MockMessage) IsReplicated() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReplicatedFrom returns the replication source.
|
||||||
|
func (m *MockMessage) GetReplicatedFrom() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSchemaValue returns the schema value.
|
||||||
|
func (m *MockMessage) GetSchemaValue(v interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEncryptionContext returns the encryption context.
|
||||||
|
func (m *MockMessage) GetEncryptionContext() *pulsar.EncryptionContext {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Index returns the message index.
|
||||||
|
func (m *MockMessage) Index() *uint64 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BrokerPublishTime returns the broker publish time.
|
||||||
|
func (m *MockMessage) BrokerPublishTime() *time.Time {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProducerName returns the producer name.
|
||||||
|
func (m *MockMessage) ProducerName() string {
|
||||||
|
return "mock-producer"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SchemaVersion returns the schema version.
|
||||||
|
func (m *MockMessage) SchemaVersion() []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplicatedFrom returns the replication source.
|
||||||
|
func (m *MockMessage) ReplicatedFrom() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockConsumerMessage creates a new mock consumer message.
|
||||||
|
func NewMockConsumerMessage(key string, payload []byte) pulsar.ConsumerMessage {
|
||||||
|
return pulsar.ConsumerMessage{
|
||||||
|
Message: NewMockMessage(key, payload),
|
||||||
|
Consumer: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
120
api/adapter/publisher.go
Normal file
120
api/adapter/publisher.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
logger2 "go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OperationTopic = "persistent://public/default/operation"
|
||||||
|
RecordTopic = "persistent://public/default/record"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PublisherConfig is the configuration to create a publisher.
|
||||||
|
type PublisherConfig struct {
|
||||||
|
// URL is the Pulsar URL.
|
||||||
|
URL string
|
||||||
|
// TLSTrustCertsFilePath is the path to the CA certificate file for verifying the server certificate.
|
||||||
|
// If empty, TLS verification will be disabled.
|
||||||
|
TLSTrustCertsFilePath string
|
||||||
|
// TLSCertificateFilePath is the path to the client certificate file for mTLS authentication.
|
||||||
|
// If empty, mTLS authentication will be disabled.
|
||||||
|
TLSCertificateFilePath string
|
||||||
|
// TLSKeyFilePath is the path to the client private key file for mTLS authentication.
|
||||||
|
// If empty, mTLS authentication will be disabled.
|
||||||
|
TLSKeyFilePath string
|
||||||
|
// TLSAllowInsecureConnection allows insecure TLS connections (not recommended for production).
|
||||||
|
TLSAllowInsecureConnection bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publisher provides the pulsar implementation for watermill publish operations.
|
||||||
|
type Publisher struct {
|
||||||
|
conn pulsar.Client
|
||||||
|
logger logger.Logger
|
||||||
|
pubs map[string]pulsar.Producer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPublisher creates a new Publisher.
|
||||||
|
func NewPublisher(config PublisherConfig, adapter logger.Logger) (*Publisher, error) {
|
||||||
|
clientOptions := pulsar.ClientOptions{
|
||||||
|
URL: config.URL,
|
||||||
|
Logger: logger2.NewPulsarLoggerAdapter(adapter),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS/mTLS
|
||||||
|
if err := configureTLSForClient(&clientOptions, config, adapter); err != nil {
|
||||||
|
return nil, errors.Join(err, errors.New("failed to configure TLS"))
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pulsar.NewClient(clientOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Join(err, errors.New("cannot connect to pulsar"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewPublisherWithPulsarClient(conn, adapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPublisherWithPulsarClient creates a new Publisher with the provided pulsar connection.
|
||||||
|
func NewPublisherWithPulsarClient(conn pulsar.Client, logger logger.Logger) (*Publisher, error) {
|
||||||
|
return &Publisher{
|
||||||
|
conn: conn,
|
||||||
|
pubs: make(map[string]pulsar.Producer),
|
||||||
|
logger: logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes message to Pulsar.
|
||||||
|
//
|
||||||
|
// Publish will not return until an ack has been received from Pulsar.
|
||||||
|
// When one of messages delivery fails - function is interrupted.
|
||||||
|
func (p *Publisher) Publish(topic string, messages ...*message.Message) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
producer, found := p.pubs[topic]
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
pr, err := p.conn.CreateProducer(pulsar.ProducerOptions{Topic: topic})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
producer = pr
|
||||||
|
p.pubs[topic] = producer
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
// 跳过 nil 消息
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.DebugContext(ctx, "Sending message", "key", msg.UUID, "topic", topic)
|
||||||
|
_, err := producer.Send(ctx, &pulsar.ProducerMessage{
|
||||||
|
Key: msg.UUID,
|
||||||
|
Payload: msg.Payload,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the publisher and the underlying connection.
|
||||||
|
func (p *Publisher) Close() error {
|
||||||
|
for _, pub := range p.pubs {
|
||||||
|
pub.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.conn.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
212
api/adapter/publisher_test.go
Normal file
212
api/adapter/publisher_test.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
package adapter_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPublisherWithPulsarClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, pub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||||
|
err = pub.Publish("test-topic", msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify message was sent
|
||||||
|
producer := mockClient.GetProducer("test-topic")
|
||||||
|
require.NotNil(t, producer)
|
||||||
|
messages := producer.GetMessages()
|
||||||
|
require.Len(t, messages, 1)
|
||||||
|
assert.Equal(t, "test-uuid", messages[0].Key)
|
||||||
|
assert.Equal(t, []byte("test payload"), messages[0].Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_MultipleMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
err = pub.Publish("test-topic", msg1, msg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producer := mockClient.GetProducer("test-topic")
|
||||||
|
require.NotNil(t, producer)
|
||||||
|
messages := producer.GetMessages()
|
||||||
|
require.Len(t, messages, 2)
|
||||||
|
assert.Equal(t, "uuid-1", messages[0].Key)
|
||||||
|
assert.Equal(t, "uuid-2", messages[1].Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_MultipleTopics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
|
||||||
|
err = pub.Publish("topic-1", msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Publish("topic-2", msg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producer1 := mockClient.GetProducer("topic-1")
|
||||||
|
require.NotNil(t, producer1)
|
||||||
|
messages1 := producer1.GetMessages()
|
||||||
|
require.Len(t, messages1, 1)
|
||||||
|
|
||||||
|
producer2 := mockClient.GetProducer("topic-2")
|
||||||
|
require.NotNil(t, producer2)
|
||||||
|
messages2 := producer2.GetMessages()
|
||||||
|
require.Len(t, messages2, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Close(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Close_AfterPublish(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||||
|
err = pub.Publish("test-topic", msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_ReuseProducer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
err = pub.Publish("test-topic", msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
err = pub.Publish("test-topic", msg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producer := mockClient.GetProducer("test-topic")
|
||||||
|
require.NotNil(t, producer)
|
||||||
|
messages := producer.GetMessages()
|
||||||
|
require.Len(t, messages, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_EmptyTopic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := message.NewMessage("uuid", []byte("payload"))
|
||||||
|
err = pub.Publish("", msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_NilMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Publish with nil message - should handle gracefully
|
||||||
|
err = pub.Publish("test-topic", nil)
|
||||||
|
// May succeed or fail depending on implementation
|
||||||
|
_ = err
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_AfterClose(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg := message.NewMessage("uuid", []byte("payload"))
|
||||||
|
err = pub.Publish("test-topic", msg)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPublisher_InvalidURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := adapter.PublisherConfig{
|
||||||
|
URL: "invalid-url",
|
||||||
|
}
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
_, err := adapter.NewPublisher(config, log)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "cannot connect")
|
||||||
|
}
|
||||||
274
api/adapter/subscriber.go
Normal file
274
api/adapter/subscriber.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill"
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
logger2 "go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SubNameKey contextKey = "subName"
|
||||||
|
ReceiverQueueSizeKey contextKey = "receiverQueueSize"
|
||||||
|
IndexKey contextKey = "index"
|
||||||
|
|
||||||
|
ReceiverQueueSizeDefault = 1000
|
||||||
|
SubNameDefault = "subName"
|
||||||
|
TimeOutDefault = time.Second * 10
|
||||||
|
defaultMessageChannelSize = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
var _ message.Subscriber = &Subscriber{}
|
||||||
|
|
||||||
|
// SubscriberConfig is the configuration to create a subscriber.
|
||||||
|
type SubscriberConfig struct {
|
||||||
|
// URL is the URL to the broker
|
||||||
|
URL string
|
||||||
|
// SubscriberName is the name of the subscription.
|
||||||
|
SubscriberName string
|
||||||
|
// SubscriberType is the type of the subscription.
|
||||||
|
SubscriberType pulsar.SubscriptionType
|
||||||
|
// TLSTrustCertsFilePath is the path to the CA certificate file for verifying the server certificate.
|
||||||
|
// If empty, TLS verification will be disabled.
|
||||||
|
TLSTrustCertsFilePath string
|
||||||
|
// TLSCertificateFilePath is the path to the client certificate file for mTLS authentication.
|
||||||
|
// If empty, mTLS authentication will be disabled.
|
||||||
|
TLSCertificateFilePath string
|
||||||
|
// TLSKeyFilePath is the path to the client private key file for mTLS authentication.
|
||||||
|
// If empty, mTLS authentication will be disabled.
|
||||||
|
TLSKeyFilePath string
|
||||||
|
// TLSAllowInsecureConnection allows insecure TLS connections (not recommended for production).
|
||||||
|
TLSAllowInsecureConnection bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscriber provides the pulsar implementation for watermill subscribe operations.
|
||||||
|
type Subscriber struct {
|
||||||
|
conn pulsar.Client
|
||||||
|
logger logger.Logger
|
||||||
|
|
||||||
|
subsLock sync.RWMutex
|
||||||
|
// Change to map with composite key: topic + subscriptionName + subName
|
||||||
|
subs map[string]pulsar.Consumer
|
||||||
|
closed bool
|
||||||
|
closing chan struct{}
|
||||||
|
SubscribersCount int
|
||||||
|
clientID string
|
||||||
|
|
||||||
|
config SubscriberConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSubscriber creates a new Subscriber.
|
||||||
|
func NewSubscriber(config SubscriberConfig, adapter logger.Logger) (*Subscriber, error) {
|
||||||
|
clientOptions := pulsar.ClientOptions{
|
||||||
|
URL: config.URL,
|
||||||
|
Logger: logger2.NewPulsarLoggerAdapter(adapter),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS/mTLS
|
||||||
|
if err := configureTLSForClient(&clientOptions, config, adapter); err != nil {
|
||||||
|
return nil, errors.Join(err, errors.New("failed to configure TLS"))
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pulsar.NewClient(clientOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Join(err, errors.New("cannot connect to Pulsar"))
|
||||||
|
}
|
||||||
|
return NewSubscriberWithPulsarClient(conn, config, adapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSubscriberWithPulsarClient creates a new Subscriber with the provided pulsar client.
|
||||||
|
func NewSubscriberWithPulsarClient(
|
||||||
|
conn pulsar.Client,
|
||||||
|
config SubscriberConfig,
|
||||||
|
logger logger.Logger,
|
||||||
|
) (*Subscriber, error) {
|
||||||
|
return &Subscriber{
|
||||||
|
conn: conn,
|
||||||
|
logger: logger,
|
||||||
|
closing: make(chan struct{}),
|
||||||
|
clientID: watermill.NewULID(),
|
||||||
|
subs: make(map[string]pulsar.Consumer),
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes messages from Pulsar.
|
||||||
|
func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
|
||||||
|
output := make(chan *message.Message)
|
||||||
|
|
||||||
|
s.subsLock.Lock()
|
||||||
|
|
||||||
|
subName, ok := ctx.Value(SubNameKey).(string)
|
||||||
|
if !ok {
|
||||||
|
subName = SubNameDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
index, ok := ctx.Value(IndexKey).(int)
|
||||||
|
if !ok {
|
||||||
|
index = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
receiverQueueSize, ok := ctx.Value(ReceiverQueueSizeKey).(int)
|
||||||
|
if !ok {
|
||||||
|
receiverQueueSize = ReceiverQueueSizeDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
subscriptionName := fmt.Sprintf("%s-%s", topic, s.clientID)
|
||||||
|
if s.config.SubscriberName != "" {
|
||||||
|
subscriptionName = s.config.SubscriberName
|
||||||
|
}
|
||||||
|
|
||||||
|
sn := fmt.Sprintf("%s_%s", subscriptionName, subName)
|
||||||
|
n := fmt.Sprintf("%s_%d", sn, index)
|
||||||
|
|
||||||
|
sub, found := s.subs[n]
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
subscribeCtx, cancel := context.WithTimeout(ctx, TimeOutDefault)
|
||||||
|
defer cancel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
var sb pulsar.Consumer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
sb, err = s.conn.Subscribe(pulsar.ConsumerOptions{
|
||||||
|
Topic: topic,
|
||||||
|
Name: n,
|
||||||
|
SubscriptionName: sn,
|
||||||
|
Type: s.config.SubscriberType,
|
||||||
|
MessageChannel: make(chan pulsar.ConsumerMessage, defaultMessageChannelSize),
|
||||||
|
ReceiverQueueSize: receiverQueueSize,
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-subscribeCtx.Done():
|
||||||
|
s.subsLock.Unlock()
|
||||||
|
return nil, fmt.Errorf("subscription timeout: %w", subscribeCtx.Err())
|
||||||
|
case <-done:
|
||||||
|
if err != nil {
|
||||||
|
s.subsLock.Unlock()
|
||||||
|
return nil, fmt.Errorf("subscription failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.subs[n] = sb
|
||||||
|
sub = sb
|
||||||
|
}
|
||||||
|
|
||||||
|
s.subsLock.Unlock()
|
||||||
|
|
||||||
|
// 创建本地引用以避免竞态条件
|
||||||
|
localSub := sub
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.closing:
|
||||||
|
s.logger.InfoContext(ctx, "subscriber is closing")
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.logger.InfoContext(ctx, "exiting on context closure")
|
||||||
|
return
|
||||||
|
case m, msgOk := <-localSub.Chan():
|
||||||
|
if !msgOk {
|
||||||
|
// Channel closed, exit the loop
|
||||||
|
s.logger.InfoContext(ctx, "consumer channel closed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.processMessage(ctx, output, m, localSub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Subscriber) processMessage(
|
||||||
|
ctx context.Context,
|
||||||
|
output chan *message.Message,
|
||||||
|
m pulsar.Message,
|
||||||
|
sub pulsar.Consumer,
|
||||||
|
) {
|
||||||
|
if s.isClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.logger.DebugContext(ctx, "Received message", "key", m.Key())
|
||||||
|
|
||||||
|
ctx, cancelCtx := context.WithCancel(ctx)
|
||||||
|
defer cancelCtx()
|
||||||
|
|
||||||
|
msg := message.NewMessage(m.Key(), m.Payload())
|
||||||
|
select {
|
||||||
|
case <-s.closing:
|
||||||
|
s.logger.DebugContext(ctx, "Closing, message discarded", "key", m.Key())
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.logger.DebugContext(ctx, "Context cancelled, message discarded")
|
||||||
|
return
|
||||||
|
// if this is first can risk 'send on closed channel' errors
|
||||||
|
case output <- msg:
|
||||||
|
s.logger.DebugContext(ctx, "Message sent to consumer")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-msg.Acked():
|
||||||
|
err := sub.Ack(m)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.DebugContext(ctx, "Message Ack Failed")
|
||||||
|
}
|
||||||
|
s.logger.DebugContext(ctx, "Message Acked")
|
||||||
|
case <-msg.Nacked():
|
||||||
|
sub.Nack(m)
|
||||||
|
s.logger.DebugContext(ctx, "Message Nacked")
|
||||||
|
case <-s.closing:
|
||||||
|
s.logger.DebugContext(ctx, "Closing, message discarded before ack")
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.logger.DebugContext(ctx, "Context cancelled, message discarded before ack")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the publisher and the underlying connection. It will attempt to wait for in-flight messages to complete.
|
||||||
|
func (s *Subscriber) Close() error {
|
||||||
|
s.subsLock.Lock()
|
||||||
|
defer s.subsLock.Unlock()
|
||||||
|
|
||||||
|
if s.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closed = true
|
||||||
|
|
||||||
|
s.logger.DebugContext(context.Background(), "Closing subscriber")
|
||||||
|
defer s.logger.InfoContext(context.Background(), "Subscriber closed")
|
||||||
|
|
||||||
|
close(s.closing)
|
||||||
|
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
sub.Close()
|
||||||
|
}
|
||||||
|
s.conn.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Subscriber) isClosed() bool {
|
||||||
|
s.subsLock.RLock()
|
||||||
|
defer s.subsLock.RUnlock()
|
||||||
|
|
||||||
|
return s.closed
|
||||||
|
}
|
||||||
216
api/adapter/subscriber_advanced_test.go
Normal file
216
api/adapter/subscriber_advanced_test.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package adapter_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithAllContextValues(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), adapter.SubNameKey, "custom-sub")
|
||||||
|
ctx = context.WithValue(ctx, adapter.IndexKey, 2)
|
||||||
|
ctx = context.WithValue(ctx, adapter.ReceiverQueueSizeKey, 1500)
|
||||||
|
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_ReuseExistingConsumer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Subscribe first time
|
||||||
|
msgChan1, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan1)
|
||||||
|
|
||||||
|
// Wait a bit
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Subscribe again with same topic - should reuse consumer
|
||||||
|
msgChan2, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_DifferentIndices(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx1 := context.WithValue(context.Background(), adapter.IndexKey, 0)
|
||||||
|
msgChan1, err := sub.Subscribe(ctx1, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan1)
|
||||||
|
|
||||||
|
ctx2 := context.WithValue(context.Background(), adapter.IndexKey, 1)
|
||||||
|
msgChan2, err := sub.Subscribe(ctx2, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithoutSubscriberName(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Close_WithMultipleSubscriptions(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = sub.Subscribe(ctx, "topic-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = sub.Subscribe(ctx, "topic-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_EmptyMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pub.Close()
|
||||||
|
|
||||||
|
// Publish with no messages - should succeed
|
||||||
|
err = pub.Publish("test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_MultipleMessagesSameTopic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pub.Close()
|
||||||
|
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
msg3 := message.NewMessage("uuid-3", []byte("payload-3"))
|
||||||
|
|
||||||
|
err = pub.Publish("test-topic", msg1, msg2, msg3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producer := mockClient.GetProducer("test-topic")
|
||||||
|
require.NotNil(t, producer)
|
||||||
|
messages := producer.GetMessages()
|
||||||
|
require.Len(t, messages, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Close_WithMultipleProducers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
|
||||||
|
err = pub.Publish("topic-1", msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Publish("topic-2", msg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Close_MultipleTimes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Close again should be safe
|
||||||
|
err = pub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
195
api/adapter/subscriber_edge_test.go
Normal file
195
api/adapter/subscriber_edge_test.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package adapter_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPulsarClientWithSubscribeError is a mock client that can return subscription errors.
|
||||||
|
type MockPulsarClientWithSubscribeError struct {
|
||||||
|
*mocks.MockPulsarClient
|
||||||
|
|
||||||
|
subscribeError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockPulsarClientWithSubscribeError() *MockPulsarClientWithSubscribeError {
|
||||||
|
return &MockPulsarClientWithSubscribeError{
|
||||||
|
MockPulsarClient: mocks.NewMockPulsarClient(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockPulsarClientWithSubscribeError) SetSubscribeError(err error) {
|
||||||
|
m.subscribeError = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockPulsarClientWithSubscribeError) Subscribe(options pulsar.ConsumerOptions) (pulsar.Consumer, error) {
|
||||||
|
if m.subscribeError != nil {
|
||||||
|
return nil, m.subscribeError
|
||||||
|
}
|
||||||
|
return m.MockPulsarClient.Subscribe(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_SubscriptionError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := NewMockPulsarClientWithSubscribeError()
|
||||||
|
mockClient.SetSubscribeError(errors.New("subscription failed"))
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "subscription failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_Timeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
// Use a very short timeout context that's already expired
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
_, err = sub.Subscribe(ctx, "test-topic")
|
||||||
|
// Should timeout or fail due to cancelled context
|
||||||
|
if err != nil {
|
||||||
|
assert.Contains(t, err.Error(), "timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithCustomSubName(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), adapter.SubNameKey, "custom-sub-name")
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithCustomIndex(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), adapter.IndexKey, 5)
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithCustomReceiverQueueSize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), adapter.ReceiverQueueSizeKey, 2000)
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_CreateProducerError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pub.Close()
|
||||||
|
|
||||||
|
// Close client before creating producer
|
||||||
|
mockClient.Close()
|
||||||
|
|
||||||
|
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||||
|
err = pub.Publish("new-topic", msg)
|
||||||
|
// Should fail when creating producer
|
||||||
|
if err != nil {
|
||||||
|
assert.Contains(t, err.Error(), "closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublisher_Publish_SendError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pub.Close()
|
||||||
|
|
||||||
|
// Create a producer first
|
||||||
|
msg1 := message.NewMessage("test-uuid-1", []byte("test payload 1"))
|
||||||
|
err = pub.Publish("test-topic", msg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Close the producer to cause send error
|
||||||
|
producer := mockClient.GetProducer("test-topic")
|
||||||
|
require.NotNil(t, producer)
|
||||||
|
producer.Close()
|
||||||
|
|
||||||
|
msg2 := message.NewMessage("test-uuid-2", []byte("test payload 2"))
|
||||||
|
err = pub.Publish("test-topic", msg2)
|
||||||
|
// May succeed or fail depending on implementation
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
259
api/adapter/subscriber_test.go
Normal file
259
api/adapter/subscriber_test.go
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
package adapter_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSubscriberWithPulsarClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_WithContextValues(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Close(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Close_AfterSubscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_MultipleTopics(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
msgChan1, err := sub.Subscribe(ctx, "topic-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan1)
|
||||||
|
|
||||||
|
msgChan2, err := sub.Subscribe(ctx, "topic-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_ReuseConsumer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
msgChan1, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msgChan2, err := sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotNil(t, msgChan1)
|
||||||
|
assert.NotNil(t, msgChan2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_ContextCancellation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
_, err = sub.Subscribe(ctx, "test-topic")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cancel context
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Wait a bit for goroutine to process cancellation
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Close subscriber
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_EmptyTopic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
msgChan, err := sub.Subscribe(ctx, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, msgChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Close_MultipleTimes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Close again should be safe
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriber_Subscribe_AfterClose(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockClient := mocks.NewMockPulsarClient()
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = sub.Subscribe(ctx, "test-topic")
|
||||||
|
// Behavior depends on implementation - may succeed or fail
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSubscriber_InvalidURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := adapter.SubscriberConfig{
|
||||||
|
URL: "invalid-url",
|
||||||
|
SubscriberName: "test-sub",
|
||||||
|
SubscriberType: pulsar.Shared,
|
||||||
|
}
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
_, err := adapter.NewSubscriber(config, log)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "cannot connect")
|
||||||
|
}
|
||||||
229
api/adapter/tcp_integration_test.go
Normal file
229
api/adapter/tcp_integration_test.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 简单的测试日志适配器
|
||||||
|
type testLogger struct{}
|
||||||
|
|
||||||
|
func (t *testLogger) InfoContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) DebugContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) WarnContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) ErrorContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) Info(msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) Debug(msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) Warn(msg string, args ...interface{}) {}
|
||||||
|
func (t *testLogger) Error(msg string, args ...interface{}) {}
|
||||||
|
|
||||||
|
func TestTCPPublisherSubscriber_Integration(t *testing.T) {
|
||||||
|
testLogger := &testLogger{}
|
||||||
|
|
||||||
|
// 创建 Subscriber
|
||||||
|
subscriberConfig := TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:18080",
|
||||||
|
}
|
||||||
|
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create subscriber: %v", err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
// 等待服务器启动
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 订阅 topic
|
||||||
|
ctx := context.Background()
|
||||||
|
topic := "test-topic"
|
||||||
|
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to subscribe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
publisherConfig := TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:18080",
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 测试发送和接收消息
|
||||||
|
testPayload := []byte("Hello, TCP Watermill!")
|
||||||
|
testMsg := message.NewMessage("test-msg-1", testPayload)
|
||||||
|
|
||||||
|
// 启动接收协程
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
select {
|
||||||
|
case receivedMsg := <-msgChan:
|
||||||
|
if string(receivedMsg.Payload) != string(testPayload) {
|
||||||
|
t.Errorf("Payload mismatch: got %s, want %s", receivedMsg.Payload, testPayload)
|
||||||
|
}
|
||||||
|
if receivedMsg.UUID != testMsg.UUID {
|
||||||
|
t.Errorf("UUID mismatch: got %s, want %s", receivedMsg.UUID, testMsg.UUID)
|
||||||
|
}
|
||||||
|
// ACK 消息
|
||||||
|
receivedMsg.Ack()
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("Timeout waiting for message")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
err = publisher.Publish(topic, testMsg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to publish message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待接收完成
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisherSubscriber_MultipleMessages(t *testing.T) {
|
||||||
|
testLogger := &testLogger{}
|
||||||
|
|
||||||
|
// 创建 Subscriber
|
||||||
|
subscriberConfig := TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:18081",
|
||||||
|
}
|
||||||
|
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create subscriber: %v", err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 订阅
|
||||||
|
ctx := context.Background()
|
||||||
|
topic := "test-topic-multi"
|
||||||
|
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to subscribe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
publisherConfig := TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:18081",
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 准备多条消息
|
||||||
|
messageCount := 10
|
||||||
|
messages := make([]*message.Message, messageCount)
|
||||||
|
for i := 0; i < messageCount; i++ {
|
||||||
|
payload := []byte("Message " + string(rune('0'+i)))
|
||||||
|
messages[i] = message.NewMessage("msg-"+string(rune('0'+i)), payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动接收协程
|
||||||
|
receivedCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i < messageCount; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
select {
|
||||||
|
case receivedMsg := <-msgChan:
|
||||||
|
mu.Lock()
|
||||||
|
receivedCount++
|
||||||
|
mu.Unlock()
|
||||||
|
receivedMsg.Ack()
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Error("Timeout waiting for message")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息(并发发送)
|
||||||
|
err = publisher.Publish(topic, messages...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to publish messages: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待接收完成
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if receivedCount != messageCount {
|
||||||
|
t.Errorf("Received count mismatch: got %d, want %d", receivedCount, messageCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisherSubscriber_Nack(t *testing.T) {
|
||||||
|
testLogger := &testLogger{}
|
||||||
|
|
||||||
|
// 创建 Subscriber
|
||||||
|
subscriberConfig := TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:18082",
|
||||||
|
}
|
||||||
|
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create subscriber: %v", err)
|
||||||
|
}
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 订阅
|
||||||
|
ctx := context.Background()
|
||||||
|
topic := "test-topic-nack"
|
||||||
|
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to subscribe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
publisherConfig := TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:18082",
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create publisher: %v", err)
|
||||||
|
}
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 准备消息
|
||||||
|
testMsg := message.NewMessage("nack-test", []byte("This will be nacked"))
|
||||||
|
|
||||||
|
// 启动接收协程,这次 NACK
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
select {
|
||||||
|
case receivedMsg := <-msgChan:
|
||||||
|
// NACK 消息
|
||||||
|
receivedMsg.Nack()
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("Timeout waiting for message")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 发送消息,由于不等待ACK,应该立即返回成功
|
||||||
|
// 注意:即使消费者NACK,发布者也会返回成功
|
||||||
|
err = publisher.Publish(topic, testMsg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error (fire-and-forget), got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
149
api/adapter/tcp_protocol.go
Normal file
149
api/adapter/tcp_protocol.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 协议常量.
|
||||||
|
const (
|
||||||
|
// MessageTypeData 表示数据消息.
|
||||||
|
MessageTypeData byte = 0x01
|
||||||
|
// MessageTypeAck 表示 ACK 确认.
|
||||||
|
MessageTypeAck byte = 0x02
|
||||||
|
// MessageTypeNack 表示 NACK 否定确认.
|
||||||
|
MessageTypeNack byte = 0x03
|
||||||
|
|
||||||
|
// 协议限制.
|
||||||
|
maxTopicLength = 65535
|
||||||
|
maxUUIDLength = 255
|
||||||
|
maxPayloadSize = 1 << 30
|
||||||
|
topicLengthSize = 2
|
||||||
|
uuidLengthSize = 1
|
||||||
|
payloadLengthSize = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// 预定义错误.
|
||||||
|
var (
|
||||||
|
ErrNilMessage = errors.New("message is nil")
|
||||||
|
ErrTopicTooLong = errors.New("topic too long")
|
||||||
|
ErrUUIDTooLong = errors.New("uuid too long")
|
||||||
|
ErrPayloadTooLarge = errors.New("payload too large")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPMessage 表示 TCP 传输的消息.
|
||||||
|
type TCPMessage struct {
|
||||||
|
Type byte // 消息类型
|
||||||
|
Topic string // 主题
|
||||||
|
UUID string // 消息 UUID
|
||||||
|
Payload []byte // 消息内容
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeTCPMessage 将消息编码为字节数组.
|
||||||
|
// 格式: [消息类型 1字节][Topic长度 2字节][Topic][UUID长度 1字节][UUID][Payload长度 4字节][Payload].
|
||||||
|
func EncodeTCPMessage(msg *TCPMessage) ([]byte, error) {
|
||||||
|
if msg == nil {
|
||||||
|
return nil, ErrNilMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
topicLen := len(msg.Topic)
|
||||||
|
if topicLen > maxTopicLength {
|
||||||
|
return nil, ErrTopicTooLong
|
||||||
|
}
|
||||||
|
|
||||||
|
uuidLen := len(msg.UUID)
|
||||||
|
if uuidLen > maxUUIDLength {
|
||||||
|
return nil, ErrUUIDTooLong
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadLen := len(msg.Payload)
|
||||||
|
if payloadLen > maxPayloadSize {
|
||||||
|
return nil, ErrPayloadTooLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总长度
|
||||||
|
totalLen := 1 + topicLengthSize + topicLen + uuidLengthSize + uuidLen + payloadLengthSize + payloadLen
|
||||||
|
buf := make([]byte, totalLen)
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
|
||||||
|
// 写入消息类型
|
||||||
|
buf[offset] = msg.Type
|
||||||
|
offset++
|
||||||
|
|
||||||
|
// 写入 Topic 长度和内容
|
||||||
|
binary.BigEndian.PutUint16(buf[offset:], uint16(topicLen))
|
||||||
|
offset += topicLengthSize
|
||||||
|
copy(buf[offset:], []byte(msg.Topic))
|
||||||
|
offset += topicLen
|
||||||
|
|
||||||
|
// 写入 UUID 长度和内容
|
||||||
|
buf[offset] = byte(uuidLen)
|
||||||
|
offset++
|
||||||
|
copy(buf[offset:], []byte(msg.UUID))
|
||||||
|
offset += uuidLen
|
||||||
|
|
||||||
|
// 写入 Payload 长度和内容
|
||||||
|
binary.BigEndian.PutUint32(buf[offset:], uint32(payloadLen))
|
||||||
|
offset += payloadLengthSize
|
||||||
|
copy(buf[offset:], msg.Payload)
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeTCPMessage 从字节数组解码消息.
|
||||||
|
func DecodeTCPMessage(reader io.Reader) (*TCPMessage, error) {
|
||||||
|
msg := &TCPMessage{}
|
||||||
|
|
||||||
|
// 读取消息类型
|
||||||
|
msgTypeBuf := make([]byte, 1)
|
||||||
|
if _, err := io.ReadFull(reader, msgTypeBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.Type = msgTypeBuf[0]
|
||||||
|
|
||||||
|
// 读取 Topic 长度
|
||||||
|
topicLenBuf := make([]byte, topicLengthSize)
|
||||||
|
if _, err := io.ReadFull(reader, topicLenBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
topicLen := binary.BigEndian.Uint16(topicLenBuf)
|
||||||
|
|
||||||
|
// 读取 Topic
|
||||||
|
topicBuf := make([]byte, topicLen)
|
||||||
|
if _, err := io.ReadFull(reader, topicBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.Topic = string(topicBuf)
|
||||||
|
|
||||||
|
// 读取 UUID 长度
|
||||||
|
uuidLenBuf := make([]byte, 1)
|
||||||
|
if _, err := io.ReadFull(reader, uuidLenBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
uuidLen := uuidLenBuf[0]
|
||||||
|
|
||||||
|
// 读取 UUID
|
||||||
|
uuidBuf := make([]byte, uuidLen)
|
||||||
|
if _, err := io.ReadFull(reader, uuidBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.UUID = string(uuidBuf)
|
||||||
|
|
||||||
|
// 读取 Payload 长度
|
||||||
|
payloadLenBuf := make([]byte, payloadLengthSize)
|
||||||
|
if _, err := io.ReadFull(reader, payloadLenBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payloadLen := binary.BigEndian.Uint32(payloadLenBuf)
|
||||||
|
|
||||||
|
// 读取 Payload
|
||||||
|
payloadBuf := make([]byte, payloadLen)
|
||||||
|
if _, err := io.ReadFull(reader, payloadBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.Payload = payloadBuf
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
166
api/adapter/tcp_protocol_test.go
Normal file
166
api/adapter/tcp_protocol_test.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeTCPMessage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msg *TCPMessage
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid data message",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeData,
|
||||||
|
Topic: "test-topic",
|
||||||
|
UUID: "test-uuid-1234",
|
||||||
|
Payload: []byte("test payload"),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid ack message",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeAck,
|
||||||
|
Topic: "",
|
||||||
|
UUID: "test-uuid-5678",
|
||||||
|
Payload: nil,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil message",
|
||||||
|
msg: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data, err := EncodeTCPMessage(tt.msg)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("EncodeTCPMessage() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr && data == nil {
|
||||||
|
t.Error("EncodeTCPMessage() returned nil data")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeTCPMessage(t *testing.T) {
|
||||||
|
// 创建一个测试消息
|
||||||
|
original := &TCPMessage{
|
||||||
|
Type: MessageTypeData,
|
||||||
|
Topic: "test-topic",
|
||||||
|
UUID: "test-uuid-1234",
|
||||||
|
Payload: []byte("test payload data"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 编码
|
||||||
|
encoded, err := EncodeTCPMessage(original)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to encode message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解码
|
||||||
|
reader := bytes.NewReader(encoded)
|
||||||
|
decoded, err := DecodeTCPMessage(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证
|
||||||
|
if decoded.Type != original.Type {
|
||||||
|
t.Errorf("Type mismatch: got %v, want %v", decoded.Type, original.Type)
|
||||||
|
}
|
||||||
|
if decoded.Topic != original.Topic {
|
||||||
|
t.Errorf("Topic mismatch: got %v, want %v", decoded.Topic, original.Topic)
|
||||||
|
}
|
||||||
|
if decoded.UUID != original.UUID {
|
||||||
|
t.Errorf("UUID mismatch: got %v, want %v", decoded.UUID, original.UUID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(decoded.Payload, original.Payload) {
|
||||||
|
t.Errorf("Payload mismatch: got %v, want %v", decoded.Payload, original.Payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeRoundTrip(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
msg *TCPMessage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "data message with payload",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeData,
|
||||||
|
Topic: "persistent://public/default/test",
|
||||||
|
UUID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
Payload: []byte("Hello, World!"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ack message",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeAck,
|
||||||
|
Topic: "",
|
||||||
|
UUID: "test-uuid",
|
||||||
|
Payload: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nack message",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeNack,
|
||||||
|
Topic: "",
|
||||||
|
UUID: "another-uuid",
|
||||||
|
Payload: []byte{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "message with large payload",
|
||||||
|
msg: &TCPMessage{
|
||||||
|
Type: MessageTypeData,
|
||||||
|
Topic: "test",
|
||||||
|
UUID: "uuid",
|
||||||
|
Payload: make([]byte, 10000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// 编码
|
||||||
|
encoded, err := EncodeTCPMessage(tc.msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解码
|
||||||
|
reader := bytes.NewReader(encoded)
|
||||||
|
decoded, err := DecodeTCPMessage(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有字段
|
||||||
|
if decoded.Type != tc.msg.Type {
|
||||||
|
t.Errorf("Type: got %v, want %v", decoded.Type, tc.msg.Type)
|
||||||
|
}
|
||||||
|
if decoded.Topic != tc.msg.Topic {
|
||||||
|
t.Errorf("Topic: got %v, want %v", decoded.Topic, tc.msg.Topic)
|
||||||
|
}
|
||||||
|
if decoded.UUID != tc.msg.UUID {
|
||||||
|
t.Errorf("UUID: got %v, want %v", decoded.UUID, tc.msg.UUID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(decoded.Payload, tc.msg.Payload) {
|
||||||
|
t.Errorf("Payload: got %v, want %v", decoded.Payload, tc.msg.Payload)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
195
api/adapter/tcp_publisher.go
Normal file
195
api/adapter/tcp_publisher.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 默认配置常量.
|
||||||
|
const (
|
||||||
|
defaultConnectTimeout = 10 * time.Second
|
||||||
|
defaultMaxRetries = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// 预定义错误.
|
||||||
|
var (
|
||||||
|
ErrServerAddrRequired = errors.New("server address is required")
|
||||||
|
ErrPublisherClosed = errors.New("publisher is closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPPublisherConfig TCP 发布者配置
|
||||||
|
type TCPPublisherConfig struct {
|
||||||
|
// ServerAddr TCP 服务器地址,格式: "host:port"
|
||||||
|
ServerAddr string
|
||||||
|
// ConnectTimeout 连接超时时间
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
// MaxRetries 最大重试次数
|
||||||
|
MaxRetries int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPPublisher 实现基于 TCP 的 watermill Publisher
|
||||||
|
type TCPPublisher struct {
|
||||||
|
config TCPPublisherConfig
|
||||||
|
conn net.Conn
|
||||||
|
logger logger.Logger
|
||||||
|
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
closeChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPPublisher 创建一个新的 TCP Publisher.
|
||||||
|
func NewTCPPublisher(config TCPPublisherConfig, logger logger.Logger) (*TCPPublisher, error) {
|
||||||
|
if config.ServerAddr == "" {
|
||||||
|
return nil, ErrServerAddrRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ConnectTimeout == 0 {
|
||||||
|
config.ConnectTimeout = defaultConnectTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.MaxRetries == 0 {
|
||||||
|
config.MaxRetries = defaultMaxRetries
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &TCPPublisher{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接到服务器
|
||||||
|
if err := p.connect(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不再接收 ACK/NACK,发送即成功模式
|
||||||
|
// go p.receiveAcks() // 已移除
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect 连接到 TCP 服务器
|
||||||
|
func (p *TCPPublisher) connect() error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), p.config.ConnectTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var d net.Dialer
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", p.config.ServerAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to %s: %w", p.config.ServerAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.conn = conn
|
||||||
|
p.logger.InfoContext(context.Background(), "Connected to TCP server", "addr", p.config.ServerAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish 发布消息.
|
||||||
|
func (p *TCPPublisher) Publish(topic string, messages ...*message.Message) error {
|
||||||
|
p.closedMu.RLock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
return ErrPublisherClosed
|
||||||
|
}
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 使用 WaitGroup 和 errChan 来并发发送消息并收集错误
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make([]error, 0, len(messages))
|
||||||
|
var errMu sync.Mutex
|
||||||
|
errChan := make(chan error, len(messages))
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(m *message.Message) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if err := p.publishSingle(ctx, topic, m); err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待所有消息发送完成
|
||||||
|
wg.Wait()
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
// 检查是否有错误
|
||||||
|
for err := range errChan {
|
||||||
|
errMu.Lock()
|
||||||
|
errs = append(errs, err)
|
||||||
|
errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return fmt.Errorf("failed to publish %d messages: %w", len(errs), errors.Join(errs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// publishSingle 发送单条消息,不等待 ACK
|
||||||
|
func (p *TCPPublisher) publishSingle(ctx context.Context, topic string, msg *message.Message) error {
|
||||||
|
tcpMsg := &TCPMessage{
|
||||||
|
Type: MessageTypeData,
|
||||||
|
Topic: topic,
|
||||||
|
UUID: msg.UUID,
|
||||||
|
Payload: msg.Payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 编码消息
|
||||||
|
data, err := EncodeTCPMessage(tcpMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encode message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.DebugContext(ctx, "Sending message", "uuid", msg.UUID, "topic", topic)
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
if _, err := p.conn.Write(data); err != nil {
|
||||||
|
return fmt.Errorf("failed to write message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.DebugContext(ctx, "Message sent successfully", "uuid", msg.UUID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveAcks, shouldStopReceiving, handleDecodeError 方法已移除
|
||||||
|
// 不再接收 ACK/NACK,采用发送即成功模式以提高性能
|
||||||
|
|
||||||
|
// Close 关闭发布者
|
||||||
|
func (p *TCPPublisher) Close() error {
|
||||||
|
p.closedMu.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
|
||||||
|
close(p.closeChan)
|
||||||
|
|
||||||
|
if p.conn != nil {
|
||||||
|
if err := p.conn.Close(); err != nil {
|
||||||
|
return fmt.Errorf("failed to close connection: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.InfoContext(context.Background(), "TCP Publisher closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
246
api/adapter/tcp_publisher_test.go
Normal file
246
api/adapter/tcp_publisher_test.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package adapter_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 验证 TCPPublisher 实现了 message.Publisher 接口
|
||||||
|
func TestTCPPublisher_ImplementsPublisherInterface(t *testing.T) {
|
||||||
|
var _ message.Publisher = (*adapter.TCPPublisher)(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTCPPublisher_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 首先创建一个订阅者作为服务器
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19090",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
// 等待服务器启动
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19090",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, publisher)
|
||||||
|
|
||||||
|
err = publisher.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTCPPublisher_InvalidServerAddr(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
_, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, adapter.ErrServerAddrRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTCPPublisher_ConnectionFailed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 尝试连接到不存在的服务器
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19999",
|
||||||
|
ConnectTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
_, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to connect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisher_Publish_NoWaitForAck(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19091",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19091",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 发送消息,应该立即返回成功,不等待ACK
|
||||||
|
msg := message.NewMessage("test-uuid-1", []byte("test payload"))
|
||||||
|
start := time.Now()
|
||||||
|
err = publisher.Publish("test-topic", msg)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
// 验证发送成功
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证发送速度很快(不应该等待ACK超时)
|
||||||
|
// 应该在100ms内返回(实际应该只需要几毫秒)
|
||||||
|
assert.Less(t, elapsed, 100*time.Millisecond, "Publish should return immediately without waiting for ACK")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisher_Publish_MultipleMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19092",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19092",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 发送多条消息
|
||||||
|
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||||
|
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||||
|
msg3 := message.NewMessage("uuid-3", []byte("payload-3"))
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err = publisher.Publish("test-topic", msg1, msg2, msg3)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 发送3条消息应该很快完成
|
||||||
|
assert.Less(t, elapsed, 200*time.Millisecond, "Publishing multiple messages should be fast")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisher_Publish_AfterClose(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19093",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19093",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 关闭 Publisher
|
||||||
|
err = publisher.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 尝试在关闭后发送消息
|
||||||
|
msg := message.NewMessage("uuid", []byte("payload"))
|
||||||
|
err = publisher.Publish("test-topic", msg)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, adapter.ErrPublisherClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisher_Publish_NilMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19094",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19094",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer publisher.Close()
|
||||||
|
|
||||||
|
// 发送 nil 消息应该被忽略
|
||||||
|
err = publisher.Publish("test-topic", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPublisher_Close_Multiple(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := logger.NewNopLogger()
|
||||||
|
|
||||||
|
// 创建订阅者
|
||||||
|
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||||
|
ListenAddr: "127.0.0.1:19095",
|
||||||
|
}
|
||||||
|
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer subscriber.Close()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 创建 Publisher
|
||||||
|
config := adapter.TCPPublisherConfig{
|
||||||
|
ServerAddr: "127.0.0.1:19095",
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 多次关闭应该不会报错
|
||||||
|
err = publisher.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = publisher.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
310
api/adapter/tcp_subscriber.go
Normal file
310
api/adapter/tcp_subscriber.go
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 订阅者配置常量.
|
||||||
|
const (
|
||||||
|
defaultOutputChannelSize = 100
|
||||||
|
minOutputChannelSize = 10
|
||||||
|
maxOutputChannelSize = 10000
|
||||||
|
)
|
||||||
|
|
||||||
|
// 预定义错误.
|
||||||
|
var (
|
||||||
|
ErrListenAddrRequired = errors.New("listen address is required")
|
||||||
|
ErrSubscriberClosed = errors.New("subscriber is closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPSubscriberConfig TCP 订阅者配置
|
||||||
|
type TCPSubscriberConfig struct {
|
||||||
|
// ListenAddr 监听地址,格式: "host:port"
|
||||||
|
ListenAddr string
|
||||||
|
|
||||||
|
// OutputChannelSize 输出 channel 的缓冲大小
|
||||||
|
// 较小的值(如 10-50):更快的背压传递,但可能降低吞吐量
|
||||||
|
// 较大的值(如 500-1000):更高的吞吐量,但背压传递较慢
|
||||||
|
// 默认值:100(平衡吞吐量和背压)
|
||||||
|
OutputChannelSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPSubscriber 实现基于 TCP 的 watermill Subscriber
|
||||||
|
type TCPSubscriber struct {
|
||||||
|
config TCPSubscriberConfig
|
||||||
|
logger logger.Logger
|
||||||
|
listener net.Listener
|
||||||
|
|
||||||
|
subsLock sync.RWMutex
|
||||||
|
subs map[string][]chan *message.Message // topic -> channels
|
||||||
|
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
closeChan chan struct{}
|
||||||
|
|
||||||
|
// 连接管理
|
||||||
|
connMu sync.Mutex
|
||||||
|
conns []net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPSubscriber 创建一个新的 TCP Subscriber.
|
||||||
|
func NewTCPSubscriber(config TCPSubscriberConfig, logger logger.Logger) (*TCPSubscriber, error) {
|
||||||
|
if config.ListenAddr == "" {
|
||||||
|
return nil, ErrListenAddrRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证和设置 channel 大小
|
||||||
|
channelSize := config.OutputChannelSize
|
||||||
|
if channelSize <= 0 {
|
||||||
|
channelSize = defaultOutputChannelSize
|
||||||
|
}
|
||||||
|
if channelSize < minOutputChannelSize {
|
||||||
|
channelSize = minOutputChannelSize
|
||||||
|
logger.WarnContext(context.Background(), "OutputChannelSize too small, using minimum",
|
||||||
|
"configured", config.OutputChannelSize, "actual", minOutputChannelSize)
|
||||||
|
}
|
||||||
|
if channelSize > maxOutputChannelSize {
|
||||||
|
channelSize = maxOutputChannelSize
|
||||||
|
logger.WarnContext(context.Background(), "OutputChannelSize too large, using maximum",
|
||||||
|
"configured", config.OutputChannelSize, "actual", maxOutputChannelSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", config.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on %s: %w", config.ListenAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新配置中的实际 channel 大小
|
||||||
|
config.OutputChannelSize = channelSize
|
||||||
|
|
||||||
|
s := &TCPSubscriber{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
listener: listener,
|
||||||
|
subs: make(map[string][]chan *message.Message),
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
conns: make([]net.Conn, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动接受连接的协程
|
||||||
|
go s.acceptConnections()
|
||||||
|
|
||||||
|
logger.InfoContext(context.Background(), "TCP Subscriber listening",
|
||||||
|
"addr", config.ListenAddr,
|
||||||
|
"channel_size", channelSize)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptConnections 接受客户端连接
|
||||||
|
func (s *TCPSubscriber) acceptConnections() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.closeChan:
|
||||||
|
s.logger.InfoContext(ctx, "Stopping connection acceptor")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
s.closedMu.RLock()
|
||||||
|
closed := s.closed
|
||||||
|
s.closedMu.RUnlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.ErrorContext(ctx, "Failed to accept connection", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.InfoContext(ctx, "Accepted new connection", "remote", conn.RemoteAddr().String())
|
||||||
|
|
||||||
|
// 保存连接
|
||||||
|
s.connMu.Lock()
|
||||||
|
s.conns = append(s.conns, conn)
|
||||||
|
s.connMu.Unlock()
|
||||||
|
|
||||||
|
// 为每个连接启动处理协程
|
||||||
|
go s.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnection 处理单个客户端连接
|
||||||
|
func (s *TCPSubscriber) handleConnection(conn net.Conn) {
|
||||||
|
ctx := context.Background()
|
||||||
|
defer func() {
|
||||||
|
conn.Close()
|
||||||
|
s.logger.InfoContext(ctx, "Connection closed", "remote", conn.RemoteAddr().String())
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// 读取消息
|
||||||
|
tcpMsg, err := DecodeTCPMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
s.closedMu.RLock()
|
||||||
|
closed := s.closed
|
||||||
|
s.closedMu.RUnlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.ErrorContext(ctx, "Failed to decode message", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tcpMsg.Type != MessageTypeData {
|
||||||
|
s.logger.WarnContext(ctx, "Unexpected message type", "type", tcpMsg.Type)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理消息
|
||||||
|
s.handleMessage(ctx, conn, tcpMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMessage 处理消息(发送即成功模式,无需 ACK/NACK)
|
||||||
|
func (s *TCPSubscriber) handleMessage(ctx context.Context, conn net.Conn, tcpMsg *TCPMessage) {
|
||||||
|
s.logger.DebugContext(ctx, "Received message", "uuid", tcpMsg.UUID, "topic", tcpMsg.Topic)
|
||||||
|
|
||||||
|
// 获取该 topic 的订阅者
|
||||||
|
s.subsLock.RLock()
|
||||||
|
channels, found := s.subs[tcpMsg.Topic]
|
||||||
|
s.subsLock.RUnlock()
|
||||||
|
|
||||||
|
if !found || len(channels) == 0 {
|
||||||
|
s.logger.WarnContext(ctx, "No subscribers for topic", "topic", tcpMsg.Topic)
|
||||||
|
// 不再发送 NACK,直接丢弃消息
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 watermill 消息
|
||||||
|
msg := message.NewMessage(tcpMsg.UUID, tcpMsg.Payload)
|
||||||
|
|
||||||
|
// 使用随机策略选择订阅者(无锁,性能更好)
|
||||||
|
randomIndex := rand.Intn(len(channels))
|
||||||
|
outputChan := channels[randomIndex]
|
||||||
|
|
||||||
|
// 记录 channel 使用情况,便于监控背压
|
||||||
|
channelLen := len(outputChan)
|
||||||
|
channelCap := cap(outputChan)
|
||||||
|
usage := float64(channelLen) / float64(channelCap) * 100
|
||||||
|
|
||||||
|
s.logger.DebugContext(ctx, "Dispatching message via random selection",
|
||||||
|
"uuid", tcpMsg.UUID,
|
||||||
|
"subscriber_index", randomIndex,
|
||||||
|
"total_subscribers", len(channels),
|
||||||
|
"channel_usage", fmt.Sprintf("%.1f%% (%d/%d)", usage, channelLen, channelCap))
|
||||||
|
|
||||||
|
// 阻塞式发送:当 channel 满时会阻塞,从而触发 TCP 背压
|
||||||
|
// 这会导致:
|
||||||
|
// 1. 当前 goroutine 阻塞
|
||||||
|
// 2. TCP 读取停止
|
||||||
|
// 3. TCP 接收窗口填满
|
||||||
|
// 4. 发送端收到零窗口通知
|
||||||
|
// 5. 发送端停止发送
|
||||||
|
select {
|
||||||
|
case outputChan <- msg:
|
||||||
|
s.logger.DebugContext(ctx, "Message sent to subscriber", "uuid", tcpMsg.UUID, "index", randomIndex)
|
||||||
|
// 发送即成功:立即 Ack 消息,不等待处理结果
|
||||||
|
msg.Ack()
|
||||||
|
case <-s.closeChan:
|
||||||
|
s.logger.DebugContext(ctx, "Subscriber closed, message discarded", "uuid", tcpMsg.UUID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不再等待消息被 ACK 或 NACK,也不发送 ACK/NACK 回执
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendAck 方法已移除
|
||||||
|
// 采用发送即成功模式,不再发送 ACK/NACK 回执以提高性能
|
||||||
|
|
||||||
|
// Subscribe 订阅指定 topic 的消息.
|
||||||
|
func (s *TCPSubscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
|
||||||
|
s.closedMu.RLock()
|
||||||
|
if s.closed {
|
||||||
|
s.closedMu.RUnlock()
|
||||||
|
return nil, ErrSubscriberClosed
|
||||||
|
}
|
||||||
|
s.closedMu.RUnlock()
|
||||||
|
|
||||||
|
// 使用配置的 channel 大小
|
||||||
|
channelSize := s.config.OutputChannelSize
|
||||||
|
if channelSize <= 0 {
|
||||||
|
channelSize = defaultOutputChannelSize
|
||||||
|
}
|
||||||
|
output := make(chan *message.Message, channelSize)
|
||||||
|
|
||||||
|
s.subsLock.Lock()
|
||||||
|
if s.subs[topic] == nil {
|
||||||
|
s.subs[topic] = make([]chan *message.Message, 0)
|
||||||
|
}
|
||||||
|
s.subs[topic] = append(s.subs[topic], output)
|
||||||
|
subscriberCount := len(s.subs[topic])
|
||||||
|
s.subsLock.Unlock()
|
||||||
|
|
||||||
|
s.logger.InfoContext(ctx, "Subscribed to topic",
|
||||||
|
"topic", topic,
|
||||||
|
"subscriber_count", subscriberCount,
|
||||||
|
"channel_size", channelSize)
|
||||||
|
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭订阅者
|
||||||
|
func (s *TCPSubscriber) Close() error {
|
||||||
|
s.closedMu.Lock()
|
||||||
|
if s.closed {
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closed = true
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
|
||||||
|
close(s.closeChan)
|
||||||
|
|
||||||
|
// 关闭监听器
|
||||||
|
if s.listener != nil {
|
||||||
|
if err := s.listener.Close(); err != nil {
|
||||||
|
s.logger.ErrorContext(context.Background(), "Failed to close listener", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭所有连接
|
||||||
|
s.connMu.Lock()
|
||||||
|
for _, conn := range s.conns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
s.connMu.Unlock()
|
||||||
|
|
||||||
|
// 关闭所有订阅通道
|
||||||
|
s.subsLock.Lock()
|
||||||
|
for topic, channels := range s.subs {
|
||||||
|
for _, ch := range channels {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
delete(s.subs, topic)
|
||||||
|
}
|
||||||
|
s.subsLock.Unlock()
|
||||||
|
|
||||||
|
s.logger.InfoContext(context.Background(), "TCP Subscriber closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
123
api/adapter/tls_config.go
Normal file
123
api/adapter/tls_config.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar/auth"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tlsConfigProvider defines the interface for TLS configuration.
|
||||||
|
type tlsConfigProvider interface {
|
||||||
|
GetTLSTrustCertsFilePath() string
|
||||||
|
GetTLSCertificateFilePath() string
|
||||||
|
GetTLSKeyFilePath() string
|
||||||
|
GetTLSAllowInsecureConnection() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureTLSForClient configures TLS/mTLS settings for the Pulsar client.
|
||||||
|
func configureTLSForClient(opts *pulsar.ClientOptions, config tlsConfigProvider, logger logger.Logger) error {
|
||||||
|
// If no TLS configuration is provided, skip TLS setup
|
||||||
|
if config.GetTLSTrustCertsFilePath() == "" &&
|
||||||
|
config.GetTLSCertificateFilePath() == "" &&
|
||||||
|
config.GetTLSKeyFilePath() == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS trust certificates
|
||||||
|
if config.GetTLSTrustCertsFilePath() != "" {
|
||||||
|
if _, err := os.ReadFile(config.GetTLSTrustCertsFilePath()); err != nil {
|
||||||
|
return errors.Join(err, errors.New("failed to read TLS trust certificates file"))
|
||||||
|
}
|
||||||
|
opts.TLSTrustCertsFilePath = config.GetTLSTrustCertsFilePath()
|
||||||
|
logger.Debug(
|
||||||
|
"TLS trust certificates configured",
|
||||||
|
"path", config.GetTLSTrustCertsFilePath(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure TLS allow insecure connection
|
||||||
|
opts.TLSAllowInsecureConnection = config.GetTLSAllowInsecureConnection()
|
||||||
|
|
||||||
|
// Configure mTLS authentication if both certificate and key are provided
|
||||||
|
if config.GetTLSCertificateFilePath() != "" && config.GetTLSKeyFilePath() != "" {
|
||||||
|
// Load client certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(
|
||||||
|
config.GetTLSCertificateFilePath(),
|
||||||
|
config.GetTLSKeyFilePath(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Join(err, errors.New("failed to load client certificate and key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS authentication provider
|
||||||
|
// Pulsar Go client uses auth.NewAuthenticationTLS with certificate and key file paths
|
||||||
|
tlsAuth := auth.NewAuthenticationTLS(
|
||||||
|
config.GetTLSCertificateFilePath(),
|
||||||
|
config.GetTLSKeyFilePath(),
|
||||||
|
)
|
||||||
|
|
||||||
|
opts.Authentication = tlsAuth
|
||||||
|
logger.Debug(
|
||||||
|
"mTLS authentication configured",
|
||||||
|
"cert", config.GetTLSCertificateFilePath(),
|
||||||
|
"key", config.GetTLSKeyFilePath(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify the certificate is valid
|
||||||
|
if _, parseErr := x509.ParseCertificate(cert.Certificate[0]); parseErr != nil {
|
||||||
|
return errors.Join(parseErr, errors.New("invalid client certificate"))
|
||||||
|
}
|
||||||
|
} else if config.GetTLSCertificateFilePath() != "" || config.GetTLSKeyFilePath() != "" {
|
||||||
|
return errors.New(
|
||||||
|
"both TLS certificate and key file paths must be provided for mTLS authentication",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSTrustCertsFilePath returns the TLS trust certificates file path for PublisherConfig.
|
||||||
|
func (c PublisherConfig) GetTLSTrustCertsFilePath() string {
|
||||||
|
return c.TLSTrustCertsFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSCertificateFilePath returns the TLS certificate file path for PublisherConfig.
|
||||||
|
func (c PublisherConfig) GetTLSCertificateFilePath() string {
|
||||||
|
return c.TLSCertificateFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSKeyFilePath returns the TLS key file path for PublisherConfig.
|
||||||
|
func (c PublisherConfig) GetTLSKeyFilePath() string {
|
||||||
|
return c.TLSKeyFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSAllowInsecureConnection returns whether to allow insecure TLS connections for PublisherConfig.
|
||||||
|
func (c PublisherConfig) GetTLSAllowInsecureConnection() bool {
|
||||||
|
return c.TLSAllowInsecureConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSTrustCertsFilePath returns the TLS trust certificates file path for SubscriberConfig.
|
||||||
|
func (c SubscriberConfig) GetTLSTrustCertsFilePath() string {
|
||||||
|
return c.TLSTrustCertsFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSCertificateFilePath returns the TLS certificate file path for SubscriberConfig.
|
||||||
|
func (c SubscriberConfig) GetTLSCertificateFilePath() string {
|
||||||
|
return c.TLSCertificateFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSKeyFilePath returns the TLS key file path for SubscriberConfig.
|
||||||
|
func (c SubscriberConfig) GetTLSKeyFilePath() string {
|
||||||
|
return c.TLSKeyFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSAllowInsecureConnection returns whether to allow insecure TLS connections for SubscriberConfig.
|
||||||
|
func (c SubscriberConfig) GetTLSAllowInsecureConnection() bool {
|
||||||
|
return c.TLSAllowInsecureConnection
|
||||||
|
}
|
||||||
7
api/compressor/compressor.go
Normal file
7
api/compressor/compressor.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package compressor
|
||||||
|
|
||||||
|
// Compressor 压缩某种元素的列表,得到压缩结果和压缩副产物
|
||||||
|
// 为存证溯源定制数据压缩方式,并非真正的压缩,而是将数据进行证明其完整性的冗余数据剔除工作,保证上分布式账本一小部分即可完成存证
|
||||||
|
type Compressor[T any] interface {
|
||||||
|
Compress(data []byte) ([]byte, error)
|
||||||
|
}
|
||||||
20
api/grpc/common.proto
Normal file
20
api/grpc/common.proto
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package common;
|
||||||
|
|
||||||
|
option go_package = "go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pb";
|
||||||
|
|
||||||
|
message MerkleTreeProofItem {
|
||||||
|
uint32 floor = 1;
|
||||||
|
string hash = 2;
|
||||||
|
bool left = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Proof {
|
||||||
|
repeated MerkleTreeProofItem colItems = 1;
|
||||||
|
repeated MerkleTreeProofItem rawItems = 2;
|
||||||
|
repeated MerkleTreeProofItem colRootItem = 3;
|
||||||
|
repeated MerkleTreeProofItem rawRootItem = 4;
|
||||||
|
string sign = 5;
|
||||||
|
string version = 6; // 版本号
|
||||||
|
}
|
||||||
5
api/grpc/generator.go
Normal file
5
api/grpc/generator.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
//go:generate protoc --go_out=./pb --go-grpc_out=./pb --go_opt=module=go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb --go-grpc_opt=module=go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb --proto_path=. ./common.proto ./operation.proto ./record.proto
|
||||||
|
// 注意:common.proto 必须首先列出,因为 operation.proto 和 record.proto 都依赖它
|
||||||
|
// 生成的代码将包含 common.pb.go,其中定义了 Proof 类型
|
||||||
72
api/grpc/operation.proto
Normal file
72
api/grpc/operation.proto
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package operation;
|
||||||
|
|
||||||
|
option go_package = "go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pb";
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
import "common.proto";
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 公共数据结构 ========================
|
||||||
|
message OperationData {
|
||||||
|
// 操作元数据信息
|
||||||
|
string op_id = 1; // 操作唯一标识符
|
||||||
|
google.protobuf.Timestamp timestamp = 2;// 操作时间戳
|
||||||
|
string op_source = 3; // 操作来源系统
|
||||||
|
string op_type = 4; // 操作类型
|
||||||
|
string do_prefix = 5; // 数据前缀标识符
|
||||||
|
string do_repository = 6; // 数据仓库标识符
|
||||||
|
string doid = 7; // 数据对象唯一标识
|
||||||
|
string producer_id = 8; // 生产者ID
|
||||||
|
string op_actor = 9; // 操作执行者信息
|
||||||
|
string request_body_hash = 10; // 请求体哈希值(可选)
|
||||||
|
string response_body_hash = 11; // 响应体哈希值(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 验证请求 & 流式响应 ========================
|
||||||
|
message ValidationReq {
|
||||||
|
google.protobuf.Timestamp time = 1; // 操作时间戳(ISO8601格式)
|
||||||
|
string op_id = 2; // 操作唯一标识符
|
||||||
|
string op_type = 3; // 操作类型
|
||||||
|
string do_repository = 4; // 数据仓库标识
|
||||||
|
}
|
||||||
|
|
||||||
|
message ValidationStreamRes {
|
||||||
|
int32 code = 1; // 状态码(100处理中,200完成,500失败)
|
||||||
|
string msg = 2; // 消息描述
|
||||||
|
string progress = 3; // 当前进度(比如 "50%")
|
||||||
|
OperationData data = 4; // 最终完成时返回,过程可为空
|
||||||
|
common.Proof proof = 5; // 取证证明(仅在完成时返回)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 列表查询请求 & 返回 ========================
|
||||||
|
message ListOperationReq {
|
||||||
|
// 分页条件
|
||||||
|
uint64 page_size = 1; // 页面大小
|
||||||
|
google.protobuf.Timestamp pre_time = 2; //上一页最后一个时间
|
||||||
|
|
||||||
|
// 可选条件
|
||||||
|
google.protobuf.Timestamp timestamp = 3;// 操作时间戳
|
||||||
|
string op_source = 4; // 操作来源
|
||||||
|
string op_type = 5; // 操作类型
|
||||||
|
string do_prefix = 6; // 数据前缀
|
||||||
|
string do_repository = 7; // 数据仓库
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListOperationRes {
|
||||||
|
int64 count=1; // 数据总量
|
||||||
|
repeated OperationData data = 2; // 数据列表
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
service OperationValidationService {
|
||||||
|
// 单个请求,服务端流式返回进度与最终结果
|
||||||
|
rpc ValidateOperation (ValidationReq) returns (stream ValidationStreamRes);
|
||||||
|
|
||||||
|
// 分页查询操作记录
|
||||||
|
rpc ListOperations (ListOperationReq) returns (ListOperationRes);
|
||||||
|
}
|
||||||
236
api/grpc/pb/common.pb.go
Normal file
236
api/grpc/pb/common.pb.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.36.10
|
||||||
|
// protoc v3.21.12
|
||||||
|
// source: common.proto
|
||||||
|
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
unsafe "unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
type MerkleTreeProofItem struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Floor uint32 `protobuf:"varint,1,opt,name=floor,proto3" json:"floor,omitempty"`
|
||||||
|
Hash string `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"`
|
||||||
|
Left bool `protobuf:"varint,3,opt,name=left,proto3" json:"left,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) Reset() {
|
||||||
|
*x = MerkleTreeProofItem{}
|
||||||
|
mi := &file_common_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*MerkleTreeProofItem) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_common_proto_msgTypes[0]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use MerkleTreeProofItem.ProtoReflect.Descriptor instead.
|
||||||
|
func (*MerkleTreeProofItem) Descriptor() ([]byte, []int) {
|
||||||
|
return file_common_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) GetFloor() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Floor
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) GetHash() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Hash
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *MerkleTreeProofItem) GetLeft() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Left
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type Proof struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
ColItems []*MerkleTreeProofItem `protobuf:"bytes,1,rep,name=colItems,proto3" json:"colItems,omitempty"`
|
||||||
|
RawItems []*MerkleTreeProofItem `protobuf:"bytes,2,rep,name=rawItems,proto3" json:"rawItems,omitempty"`
|
||||||
|
ColRootItem []*MerkleTreeProofItem `protobuf:"bytes,3,rep,name=colRootItem,proto3" json:"colRootItem,omitempty"`
|
||||||
|
RawRootItem []*MerkleTreeProofItem `protobuf:"bytes,4,rep,name=rawRootItem,proto3" json:"rawRootItem,omitempty"`
|
||||||
|
Sign string `protobuf:"bytes,5,opt,name=sign,proto3" json:"sign,omitempty"`
|
||||||
|
Version string `protobuf:"bytes,6,opt,name=version,proto3" json:"version,omitempty"` // 版本号
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) Reset() {
|
||||||
|
*x = Proof{}
|
||||||
|
mi := &file_common_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Proof) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *Proof) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_common_proto_msgTypes[1]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use Proof.ProtoReflect.Descriptor instead.
|
||||||
|
func (*Proof) Descriptor() ([]byte, []int) {
|
||||||
|
return file_common_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetColItems() []*MerkleTreeProofItem {
|
||||||
|
if x != nil {
|
||||||
|
return x.ColItems
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetRawItems() []*MerkleTreeProofItem {
|
||||||
|
if x != nil {
|
||||||
|
return x.RawItems
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetColRootItem() []*MerkleTreeProofItem {
|
||||||
|
if x != nil {
|
||||||
|
return x.ColRootItem
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetRawRootItem() []*MerkleTreeProofItem {
|
||||||
|
if x != nil {
|
||||||
|
return x.RawRootItem
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetSign() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Sign
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Proof) GetVersion() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Version
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_common_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
const file_common_proto_rawDesc = "" +
|
||||||
|
"\n" +
|
||||||
|
"\fcommon.proto\x12\x06common\"S\n" +
|
||||||
|
"\x13MerkleTreeProofItem\x12\x14\n" +
|
||||||
|
"\x05floor\x18\x01 \x01(\rR\x05floor\x12\x12\n" +
|
||||||
|
"\x04hash\x18\x02 \x01(\tR\x04hash\x12\x12\n" +
|
||||||
|
"\x04left\x18\x03 \x01(\bR\x04left\"\xa5\x02\n" +
|
||||||
|
"\x05Proof\x127\n" +
|
||||||
|
"\bcolItems\x18\x01 \x03(\v2\x1b.common.MerkleTreeProofItemR\bcolItems\x127\n" +
|
||||||
|
"\brawItems\x18\x02 \x03(\v2\x1b.common.MerkleTreeProofItemR\brawItems\x12=\n" +
|
||||||
|
"\vcolRootItem\x18\x03 \x03(\v2\x1b.common.MerkleTreeProofItemR\vcolRootItem\x12=\n" +
|
||||||
|
"\vrawRootItem\x18\x04 \x03(\v2\x1b.common.MerkleTreeProofItemR\vrawRootItem\x12\x12\n" +
|
||||||
|
"\x04sign\x18\x05 \x01(\tR\x04sign\x12\x18\n" +
|
||||||
|
"\aversion\x18\x06 \x01(\tR\aversionB4Z2go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pbb\x06proto3"
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_common_proto_rawDescOnce sync.Once
|
||||||
|
file_common_proto_rawDescData []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_common_proto_rawDescGZIP() []byte {
|
||||||
|
file_common_proto_rawDescOnce.Do(func() {
|
||||||
|
file_common_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_common_proto_rawDesc), len(file_common_proto_rawDesc)))
|
||||||
|
})
|
||||||
|
return file_common_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_common_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||||
|
var file_common_proto_goTypes = []any{
|
||||||
|
(*MerkleTreeProofItem)(nil), // 0: common.MerkleTreeProofItem
|
||||||
|
(*Proof)(nil), // 1: common.Proof
|
||||||
|
}
|
||||||
|
var file_common_proto_depIdxs = []int32{
|
||||||
|
0, // 0: common.Proof.colItems:type_name -> common.MerkleTreeProofItem
|
||||||
|
0, // 1: common.Proof.rawItems:type_name -> common.MerkleTreeProofItem
|
||||||
|
0, // 2: common.Proof.colRootItem:type_name -> common.MerkleTreeProofItem
|
||||||
|
0, // 3: common.Proof.rawRootItem:type_name -> common.MerkleTreeProofItem
|
||||||
|
4, // [4:4] is the sub-list for method output_type
|
||||||
|
4, // [4:4] is the sub-list for method input_type
|
||||||
|
4, // [4:4] is the sub-list for extension type_name
|
||||||
|
4, // [4:4] is the sub-list for extension extendee
|
||||||
|
0, // [0:4] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_common_proto_init() }
|
||||||
|
func file_common_proto_init() {
|
||||||
|
if File_common_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_common_proto_rawDesc), len(file_common_proto_rawDesc)),
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 2,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 0,
|
||||||
|
},
|
||||||
|
GoTypes: file_common_proto_goTypes,
|
||||||
|
DependencyIndexes: file_common_proto_depIdxs,
|
||||||
|
MessageInfos: file_common_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_common_proto = out.File
|
||||||
|
file_common_proto_goTypes = nil
|
||||||
|
file_common_proto_depIdxs = nil
|
||||||
|
}
|
||||||
552
api/grpc/pb/operation.pb.go
Normal file
552
api/grpc/pb/operation.pb.go
Normal file
@@ -0,0 +1,552 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.36.10
|
||||||
|
// protoc v3.21.12
|
||||||
|
// source: operation.proto
|
||||||
|
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
unsafe "unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ======================== 公共数据结构 ========================
|
||||||
|
type OperationData struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// 操作元数据信息
|
||||||
|
OpId string `protobuf:"bytes,1,opt,name=op_id,json=opId,proto3" json:"op_id,omitempty"` // 操作唯一标识符
|
||||||
|
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // 操作时间戳
|
||||||
|
OpSource string `protobuf:"bytes,3,opt,name=op_source,json=opSource,proto3" json:"op_source,omitempty"` // 操作来源系统
|
||||||
|
OpType string `protobuf:"bytes,4,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` // 操作类型
|
||||||
|
DoPrefix string `protobuf:"bytes,5,opt,name=do_prefix,json=doPrefix,proto3" json:"do_prefix,omitempty"` // 数据前缀标识符
|
||||||
|
DoRepository string `protobuf:"bytes,6,opt,name=do_repository,json=doRepository,proto3" json:"do_repository,omitempty"` // 数据仓库标识符
|
||||||
|
Doid string `protobuf:"bytes,7,opt,name=doid,proto3" json:"doid,omitempty"` // 数据对象唯一标识
|
||||||
|
ProducerId string `protobuf:"bytes,8,opt,name=producer_id,json=producerId,proto3" json:"producer_id,omitempty"` // 生产者ID
|
||||||
|
OpActor string `protobuf:"bytes,9,opt,name=op_actor,json=opActor,proto3" json:"op_actor,omitempty"` // 操作执行者信息
|
||||||
|
RequestBodyHash string `protobuf:"bytes,10,opt,name=request_body_hash,json=requestBodyHash,proto3" json:"request_body_hash,omitempty"` // 请求体哈希值(可选)
|
||||||
|
ResponseBodyHash string `protobuf:"bytes,11,opt,name=response_body_hash,json=responseBodyHash,proto3" json:"response_body_hash,omitempty"` // 响应体哈希值(可选)
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) Reset() {
|
||||||
|
*x = OperationData{}
|
||||||
|
mi := &file_operation_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*OperationData) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *OperationData) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_operation_proto_msgTypes[0]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use OperationData.ProtoReflect.Descriptor instead.
|
||||||
|
func (*OperationData) Descriptor() ([]byte, []int) {
|
||||||
|
return file_operation_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetOpId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetTimestamp() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Timestamp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetOpSource() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpSource
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetOpType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetDoPrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoPrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetDoRepository() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoRepository
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetDoid() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Doid
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetProducerId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ProducerId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetOpActor() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpActor
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetRequestBodyHash() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.RequestBodyHash
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *OperationData) GetResponseBodyHash() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ResponseBodyHash
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================== 验证请求 & 流式响应 ========================
|
||||||
|
type ValidationReq struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Time *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=time,proto3" json:"time,omitempty"` // 操作时间戳(ISO8601格式)
|
||||||
|
OpId string `protobuf:"bytes,2,opt,name=op_id,json=opId,proto3" json:"op_id,omitempty"` // 操作唯一标识符
|
||||||
|
OpType string `protobuf:"bytes,3,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` // 操作类型
|
||||||
|
DoRepository string `protobuf:"bytes,4,opt,name=do_repository,json=doRepository,proto3" json:"do_repository,omitempty"` // 数据仓库标识
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) Reset() {
|
||||||
|
*x = ValidationReq{}
|
||||||
|
mi := &file_operation_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ValidationReq) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ValidationReq) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_operation_proto_msgTypes[1]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ValidationReq.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ValidationReq) Descriptor() ([]byte, []int) {
|
||||||
|
return file_operation_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) GetTime() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Time
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) GetOpId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) GetOpType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationReq) GetDoRepository() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoRepository
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ValidationStreamRes struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` // 状态码(100处理中,200完成,500失败)
|
||||||
|
Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` // 消息描述
|
||||||
|
Progress string `protobuf:"bytes,3,opt,name=progress,proto3" json:"progress,omitempty"` // 当前进度(比如 "50%")
|
||||||
|
Data *OperationData `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // 最终完成时返回,过程可为空
|
||||||
|
Proof *Proof `protobuf:"bytes,5,opt,name=proof,proto3" json:"proof,omitempty"` // 取证证明(仅在完成时返回)
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) Reset() {
|
||||||
|
*x = ValidationStreamRes{}
|
||||||
|
mi := &file_operation_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ValidationStreamRes) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_operation_proto_msgTypes[2]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ValidationStreamRes.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ValidationStreamRes) Descriptor() ([]byte, []int) {
|
||||||
|
return file_operation_proto_rawDescGZIP(), []int{2}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) GetCode() int32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Code
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) GetMsg() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Msg
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) GetProgress() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Progress
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) GetData() *OperationData {
|
||||||
|
if x != nil {
|
||||||
|
return x.Data
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ValidationStreamRes) GetProof() *Proof {
|
||||||
|
if x != nil {
|
||||||
|
return x.Proof
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================== 列表查询请求 & 返回 ========================
|
||||||
|
type ListOperationReq struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// 分页条件
|
||||||
|
PageSize uint64 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` // 页面大小
|
||||||
|
PreTime *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=pre_time,json=preTime,proto3" json:"pre_time,omitempty"` //上一页最后一个时间
|
||||||
|
// 可选条件
|
||||||
|
Timestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // 操作时间戳
|
||||||
|
OpSource string `protobuf:"bytes,4,opt,name=op_source,json=opSource,proto3" json:"op_source,omitempty"` // 操作来源
|
||||||
|
OpType string `protobuf:"bytes,5,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` // 操作类型
|
||||||
|
DoPrefix string `protobuf:"bytes,6,opt,name=do_prefix,json=doPrefix,proto3" json:"do_prefix,omitempty"` // 数据前缀
|
||||||
|
DoRepository string `protobuf:"bytes,7,opt,name=do_repository,json=doRepository,proto3" json:"do_repository,omitempty"` // 数据仓库
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) Reset() {
|
||||||
|
*x = ListOperationReq{}
|
||||||
|
mi := &file_operation_proto_msgTypes[3]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListOperationReq) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_operation_proto_msgTypes[3]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListOperationReq.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListOperationReq) Descriptor() ([]byte, []int) {
|
||||||
|
return file_operation_proto_rawDescGZIP(), []int{3}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetPageSize() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.PageSize
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetPreTime() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.PreTime
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetTimestamp() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Timestamp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetOpSource() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpSource
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetOpType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.OpType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetDoPrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoPrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationReq) GetDoRepository() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoRepository
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListOperationRes struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Count int64 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` // 数据总量
|
||||||
|
Data []*OperationData `protobuf:"bytes,2,rep,name=data,proto3" json:"data,omitempty"` // 数据列表
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationRes) Reset() {
|
||||||
|
*x = ListOperationRes{}
|
||||||
|
mi := &file_operation_proto_msgTypes[4]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationRes) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListOperationRes) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListOperationRes) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_operation_proto_msgTypes[4]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListOperationRes.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListOperationRes) Descriptor() ([]byte, []int) {
|
||||||
|
return file_operation_proto_rawDescGZIP(), []int{4}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationRes) GetCount() int64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Count
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListOperationRes) GetData() []*OperationData {
|
||||||
|
if x != nil {
|
||||||
|
return x.Data
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_operation_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
const file_operation_proto_rawDesc = "" +
|
||||||
|
"\n" +
|
||||||
|
"\x0foperation.proto\x12\toperation\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\fcommon.proto\"\x80\x03\n" +
|
||||||
|
"\rOperationData\x12\x13\n" +
|
||||||
|
"\x05op_id\x18\x01 \x01(\tR\x04opId\x128\n" +
|
||||||
|
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1b\n" +
|
||||||
|
"\top_source\x18\x03 \x01(\tR\bopSource\x12\x17\n" +
|
||||||
|
"\aop_type\x18\x04 \x01(\tR\x06opType\x12\x1b\n" +
|
||||||
|
"\tdo_prefix\x18\x05 \x01(\tR\bdoPrefix\x12#\n" +
|
||||||
|
"\rdo_repository\x18\x06 \x01(\tR\fdoRepository\x12\x12\n" +
|
||||||
|
"\x04doid\x18\a \x01(\tR\x04doid\x12\x1f\n" +
|
||||||
|
"\vproducer_id\x18\b \x01(\tR\n" +
|
||||||
|
"producerId\x12\x19\n" +
|
||||||
|
"\bop_actor\x18\t \x01(\tR\aopActor\x12*\n" +
|
||||||
|
"\x11request_body_hash\x18\n" +
|
||||||
|
" \x01(\tR\x0frequestBodyHash\x12,\n" +
|
||||||
|
"\x12response_body_hash\x18\v \x01(\tR\x10responseBodyHash\"\x92\x01\n" +
|
||||||
|
"\rValidationReq\x12.\n" +
|
||||||
|
"\x04time\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\x04time\x12\x13\n" +
|
||||||
|
"\x05op_id\x18\x02 \x01(\tR\x04opId\x12\x17\n" +
|
||||||
|
"\aop_type\x18\x03 \x01(\tR\x06opType\x12#\n" +
|
||||||
|
"\rdo_repository\x18\x04 \x01(\tR\fdoRepository\"\xaa\x01\n" +
|
||||||
|
"\x13ValidationStreamRes\x12\x12\n" +
|
||||||
|
"\x04code\x18\x01 \x01(\x05R\x04code\x12\x10\n" +
|
||||||
|
"\x03msg\x18\x02 \x01(\tR\x03msg\x12\x1a\n" +
|
||||||
|
"\bprogress\x18\x03 \x01(\tR\bprogress\x12,\n" +
|
||||||
|
"\x04data\x18\x04 \x01(\v2\x18.operation.OperationDataR\x04data\x12#\n" +
|
||||||
|
"\x05proof\x18\x05 \x01(\v2\r.common.ProofR\x05proof\"\x98\x02\n" +
|
||||||
|
"\x10ListOperationReq\x12\x1b\n" +
|
||||||
|
"\tpage_size\x18\x01 \x01(\x04R\bpageSize\x125\n" +
|
||||||
|
"\bpre_time\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\apreTime\x128\n" +
|
||||||
|
"\ttimestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1b\n" +
|
||||||
|
"\top_source\x18\x04 \x01(\tR\bopSource\x12\x17\n" +
|
||||||
|
"\aop_type\x18\x05 \x01(\tR\x06opType\x12\x1b\n" +
|
||||||
|
"\tdo_prefix\x18\x06 \x01(\tR\bdoPrefix\x12#\n" +
|
||||||
|
"\rdo_repository\x18\a \x01(\tR\fdoRepository\"V\n" +
|
||||||
|
"\x10ListOperationRes\x12\x14\n" +
|
||||||
|
"\x05count\x18\x01 \x01(\x03R\x05count\x12,\n" +
|
||||||
|
"\x04data\x18\x02 \x03(\v2\x18.operation.OperationDataR\x04data2\xb9\x01\n" +
|
||||||
|
"\x1aOperationValidationService\x12O\n" +
|
||||||
|
"\x11ValidateOperation\x12\x18.operation.ValidationReq\x1a\x1e.operation.ValidationStreamRes0\x01\x12J\n" +
|
||||||
|
"\x0eListOperations\x12\x1b.operation.ListOperationReq\x1a\x1b.operation.ListOperationResB4Z2go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pbb\x06proto3"
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_operation_proto_rawDescOnce sync.Once
|
||||||
|
file_operation_proto_rawDescData []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_operation_proto_rawDescGZIP() []byte {
|
||||||
|
file_operation_proto_rawDescOnce.Do(func() {
|
||||||
|
file_operation_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_operation_proto_rawDesc), len(file_operation_proto_rawDesc)))
|
||||||
|
})
|
||||||
|
return file_operation_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_operation_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||||
|
var file_operation_proto_goTypes = []any{
|
||||||
|
(*OperationData)(nil), // 0: operation.OperationData
|
||||||
|
(*ValidationReq)(nil), // 1: operation.ValidationReq
|
||||||
|
(*ValidationStreamRes)(nil), // 2: operation.ValidationStreamRes
|
||||||
|
(*ListOperationReq)(nil), // 3: operation.ListOperationReq
|
||||||
|
(*ListOperationRes)(nil), // 4: operation.ListOperationRes
|
||||||
|
(*timestamppb.Timestamp)(nil), // 5: google.protobuf.Timestamp
|
||||||
|
(*Proof)(nil), // 6: common.Proof
|
||||||
|
}
|
||||||
|
var file_operation_proto_depIdxs = []int32{
|
||||||
|
5, // 0: operation.OperationData.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
|
5, // 1: operation.ValidationReq.time:type_name -> google.protobuf.Timestamp
|
||||||
|
0, // 2: operation.ValidationStreamRes.data:type_name -> operation.OperationData
|
||||||
|
6, // 3: operation.ValidationStreamRes.proof:type_name -> common.Proof
|
||||||
|
5, // 4: operation.ListOperationReq.pre_time:type_name -> google.protobuf.Timestamp
|
||||||
|
5, // 5: operation.ListOperationReq.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
|
0, // 6: operation.ListOperationRes.data:type_name -> operation.OperationData
|
||||||
|
1, // 7: operation.OperationValidationService.ValidateOperation:input_type -> operation.ValidationReq
|
||||||
|
3, // 8: operation.OperationValidationService.ListOperations:input_type -> operation.ListOperationReq
|
||||||
|
2, // 9: operation.OperationValidationService.ValidateOperation:output_type -> operation.ValidationStreamRes
|
||||||
|
4, // 10: operation.OperationValidationService.ListOperations:output_type -> operation.ListOperationRes
|
||||||
|
9, // [9:11] is the sub-list for method output_type
|
||||||
|
7, // [7:9] is the sub-list for method input_type
|
||||||
|
7, // [7:7] is the sub-list for extension type_name
|
||||||
|
7, // [7:7] is the sub-list for extension extendee
|
||||||
|
0, // [0:7] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_operation_proto_init() }
|
||||||
|
func file_operation_proto_init() {
|
||||||
|
if File_operation_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file_common_proto_init()
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_operation_proto_rawDesc), len(file_operation_proto_rawDesc)),
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 5,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 1,
|
||||||
|
},
|
||||||
|
GoTypes: file_operation_proto_goTypes,
|
||||||
|
DependencyIndexes: file_operation_proto_depIdxs,
|
||||||
|
MessageInfos: file_operation_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_operation_proto = out.File
|
||||||
|
file_operation_proto_goTypes = nil
|
||||||
|
file_operation_proto_depIdxs = nil
|
||||||
|
}
|
||||||
172
api/grpc/pb/operation_grpc.pb.go
Normal file
172
api/grpc/pb/operation_grpc.pb.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// - protoc-gen-go-grpc v1.5.1
|
||||||
|
// - protoc v3.21.12
|
||||||
|
// source: operation.proto
|
||||||
|
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
// Requires gRPC-Go v1.64.0 or later.
|
||||||
|
const _ = grpc.SupportPackageIsVersion9
|
||||||
|
|
||||||
|
const (
|
||||||
|
OperationValidationService_ValidateOperation_FullMethodName = "/operation.OperationValidationService/ValidateOperation"
|
||||||
|
OperationValidationService_ListOperations_FullMethodName = "/operation.OperationValidationService/ListOperations"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OperationValidationServiceClient is the client API for OperationValidationService service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
|
//
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
type OperationValidationServiceClient interface {
|
||||||
|
// 单个请求,服务端流式返回进度与最终结果
|
||||||
|
ValidateOperation(ctx context.Context, in *ValidationReq, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ValidationStreamRes], error)
|
||||||
|
// 分页查询操作记录
|
||||||
|
ListOperations(ctx context.Context, in *ListOperationReq, opts ...grpc.CallOption) (*ListOperationRes, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type operationValidationServiceClient struct {
|
||||||
|
cc grpc.ClientConnInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOperationValidationServiceClient(cc grpc.ClientConnInterface) OperationValidationServiceClient {
|
||||||
|
return &operationValidationServiceClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *operationValidationServiceClient) ValidateOperation(ctx context.Context, in *ValidationReq, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ValidationStreamRes], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &OperationValidationService_ServiceDesc.Streams[0], OperationValidationService_ValidateOperation_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[ValidationReq, ValidationStreamRes]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type OperationValidationService_ValidateOperationClient = grpc.ServerStreamingClient[ValidationStreamRes]
|
||||||
|
|
||||||
|
func (c *operationValidationServiceClient) ListOperations(ctx context.Context, in *ListOperationReq, opts ...grpc.CallOption) (*ListOperationRes, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(ListOperationRes)
|
||||||
|
err := c.cc.Invoke(ctx, OperationValidationService_ListOperations_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OperationValidationServiceServer is the server API for OperationValidationService service.
|
||||||
|
// All implementations must embed UnimplementedOperationValidationServiceServer
|
||||||
|
// for forward compatibility.
|
||||||
|
//
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
type OperationValidationServiceServer interface {
|
||||||
|
// 单个请求,服务端流式返回进度与最终结果
|
||||||
|
ValidateOperation(*ValidationReq, grpc.ServerStreamingServer[ValidationStreamRes]) error
|
||||||
|
// 分页查询操作记录
|
||||||
|
ListOperations(context.Context, *ListOperationReq) (*ListOperationRes, error)
|
||||||
|
mustEmbedUnimplementedOperationValidationServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedOperationValidationServiceServer must be embedded to have
|
||||||
|
// forward compatible implementations.
|
||||||
|
//
|
||||||
|
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||||
|
// pointer dereference when methods are called.
|
||||||
|
type UnimplementedOperationValidationServiceServer struct{}
|
||||||
|
|
||||||
|
func (UnimplementedOperationValidationServiceServer) ValidateOperation(*ValidationReq, grpc.ServerStreamingServer[ValidationStreamRes]) error {
|
||||||
|
return status.Errorf(codes.Unimplemented, "method ValidateOperation not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedOperationValidationServiceServer) ListOperations(context.Context, *ListOperationReq) (*ListOperationRes, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method ListOperations not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedOperationValidationServiceServer) mustEmbedUnimplementedOperationValidationServiceServer() {
|
||||||
|
}
|
||||||
|
func (UnimplementedOperationValidationServiceServer) testEmbeddedByValue() {}
|
||||||
|
|
||||||
|
// UnsafeOperationValidationServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
|
// Use of this interface is not recommended, as added methods to OperationValidationServiceServer will
|
||||||
|
// result in compilation errors.
|
||||||
|
type UnsafeOperationValidationServiceServer interface {
|
||||||
|
mustEmbedUnimplementedOperationValidationServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterOperationValidationServiceServer(s grpc.ServiceRegistrar, srv OperationValidationServiceServer) {
|
||||||
|
// If the following call pancis, it indicates UnimplementedOperationValidationServiceServer was
|
||||||
|
// embedded by pointer and is nil. This will cause panics if an
|
||||||
|
// unimplemented method is ever invoked, so we test this at initialization
|
||||||
|
// time to prevent it from happening at runtime later due to I/O.
|
||||||
|
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||||
|
t.testEmbeddedByValue()
|
||||||
|
}
|
||||||
|
s.RegisterService(&OperationValidationService_ServiceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _OperationValidationService_ValidateOperation_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(ValidationReq)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(OperationValidationServiceServer).ValidateOperation(m, &grpc.GenericServerStream[ValidationReq, ValidationStreamRes]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type OperationValidationService_ValidateOperationServer = grpc.ServerStreamingServer[ValidationStreamRes]
|
||||||
|
|
||||||
|
func _OperationValidationService_ListOperations_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(ListOperationReq)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(OperationValidationServiceServer).ListOperations(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: OperationValidationService_ListOperations_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(OperationValidationServiceServer).ListOperations(ctx, req.(*ListOperationReq))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OperationValidationService_ServiceDesc is the grpc.ServiceDesc for OperationValidationService service.
|
||||||
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
|
// and not to be introspected or modified (even as a copy)
|
||||||
|
var OperationValidationService_ServiceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "operation.OperationValidationService",
|
||||||
|
HandlerType: (*OperationValidationServiceServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{
|
||||||
|
{
|
||||||
|
MethodName: "ListOperations",
|
||||||
|
Handler: _OperationValidationService_ListOperations_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{
|
||||||
|
{
|
||||||
|
StreamName: "ValidateOperation",
|
||||||
|
Handler: _OperationValidationService_ValidateOperation_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: "operation.proto",
|
||||||
|
}
|
||||||
489
api/grpc/pb/record.pb.go
Normal file
489
api/grpc/pb/record.pb.go
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.36.10
|
||||||
|
// protoc v3.21.12
|
||||||
|
// source: record.proto
|
||||||
|
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
unsafe "unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ======================== 公共数据结构 ========================
|
||||||
|
type RecordData struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// 记录核心信息
|
||||||
|
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // 记录唯一标识符(必填)
|
||||||
|
DoPrefix string `protobuf:"bytes,2,opt,name=do_prefix,json=doPrefix,proto3" json:"do_prefix,omitempty"` // 数据前缀标识符
|
||||||
|
ProducerId string `protobuf:"bytes,3,opt,name=producer_id,json=producerId,proto3" json:"producer_id,omitempty"` // 生产者ID
|
||||||
|
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // 记录时间戳
|
||||||
|
Operator string `protobuf:"bytes,5,opt,name=operator,proto3" json:"operator,omitempty"` // 操作执行者标识
|
||||||
|
Extra []byte `protobuf:"bytes,6,opt,name=extra,proto3" json:"extra,omitempty"` // 额外数据字段
|
||||||
|
RcType string `protobuf:"bytes,7,opt,name=rc_type,json=rcType,proto3" json:"rc_type,omitempty"` // 记录类型
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) Reset() {
|
||||||
|
*x = RecordData{}
|
||||||
|
mi := &file_record_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RecordData) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RecordData) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_record_proto_msgTypes[0]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RecordData.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RecordData) Descriptor() ([]byte, []int) {
|
||||||
|
return file_record_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Id
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetDoPrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoPrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetProducerId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ProducerId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetTimestamp() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Timestamp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetOperator() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Operator
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetExtra() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.Extra
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordData) GetRcType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.RcType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================== 列表查询请求 & 返回 ========================
|
||||||
|
type ListRecordReq struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// 分页条件
|
||||||
|
PageSize uint64 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` // 页面大小
|
||||||
|
PreTime *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=pre_time,json=preTime,proto3" json:"pre_time,omitempty"` // 上一页最后一个时间
|
||||||
|
// 可选过滤条件
|
||||||
|
DoPrefix string `protobuf:"bytes,3,opt,name=do_prefix,json=doPrefix,proto3" json:"do_prefix,omitempty"` // 数据前缀过滤
|
||||||
|
RcType string `protobuf:"bytes,4,opt,name=rc_type,json=rcType,proto3" json:"rc_type,omitempty"` // 记录类型过滤
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) Reset() {
|
||||||
|
*x = ListRecordReq{}
|
||||||
|
mi := &file_record_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListRecordReq) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_record_proto_msgTypes[1]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListRecordReq.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListRecordReq) Descriptor() ([]byte, []int) {
|
||||||
|
return file_record_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) GetPageSize() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.PageSize
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) GetPreTime() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.PreTime
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) GetDoPrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoPrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordReq) GetRcType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.RcType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListRecordRes struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Count int64 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` // 数据总量
|
||||||
|
Data []*RecordData `protobuf:"bytes,2,rep,name=data,proto3" json:"data,omitempty"` // 记录数据列表
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordRes) Reset() {
|
||||||
|
*x = ListRecordRes{}
|
||||||
|
mi := &file_record_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordRes) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListRecordRes) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListRecordRes) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_record_proto_msgTypes[2]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListRecordRes.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListRecordRes) Descriptor() ([]byte, []int) {
|
||||||
|
return file_record_proto_rawDescGZIP(), []int{2}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordRes) GetCount() int64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Count
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListRecordRes) GetData() []*RecordData {
|
||||||
|
if x != nil {
|
||||||
|
return x.Data
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================== 记录验证请求 & 流式响应 ========================
|
||||||
|
type RecordValidationReq struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // 记录时间戳
|
||||||
|
RecordId string `protobuf:"bytes,2,opt,name=record_id,json=recordId,proto3" json:"record_id,omitempty"` // 要验证的记录ID
|
||||||
|
DoPrefix string `protobuf:"bytes,3,opt,name=do_prefix,json=doPrefix,proto3" json:"do_prefix,omitempty"` // 数据前缀(可选)
|
||||||
|
RcType string `protobuf:"bytes,4,opt,name=rc_type,json=rcType,proto3" json:"rc_type,omitempty"` // 记录类型
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) Reset() {
|
||||||
|
*x = RecordValidationReq{}
|
||||||
|
mi := &file_record_proto_msgTypes[3]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RecordValidationReq) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_record_proto_msgTypes[3]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RecordValidationReq.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RecordValidationReq) Descriptor() ([]byte, []int) {
|
||||||
|
return file_record_proto_rawDescGZIP(), []int{3}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) GetTimestamp() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Timestamp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) GetRecordId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.RecordId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) GetDoPrefix() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DoPrefix
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationReq) GetRcType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.RcType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecordValidationStreamRes struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` // 状态码(100处理中,200完成,400客户端错误,500服务器错误)
|
||||||
|
Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` // 消息描述
|
||||||
|
Progress string `protobuf:"bytes,3,opt,name=progress,proto3" json:"progress,omitempty"` // 验证进度(如 "30%", "验证哈希完成")
|
||||||
|
// 验证结果详情(仅在完成时返回)
|
||||||
|
Result *RecordData `protobuf:"bytes,4,opt,name=result,proto3" json:"result,omitempty"`
|
||||||
|
Proof *Proof `protobuf:"bytes,5,opt,name=proof,proto3" json:"proof,omitempty"` // 取证证明(仅在完成时返回)
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) Reset() {
|
||||||
|
*x = RecordValidationStreamRes{}
|
||||||
|
mi := &file_record_proto_msgTypes[4]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RecordValidationStreamRes) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_record_proto_msgTypes[4]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RecordValidationStreamRes.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RecordValidationStreamRes) Descriptor() ([]byte, []int) {
|
||||||
|
return file_record_proto_rawDescGZIP(), []int{4}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) GetCode() int32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Code
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) GetMsg() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Msg
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) GetProgress() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Progress
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) GetResult() *RecordData {
|
||||||
|
if x != nil {
|
||||||
|
return x.Result
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RecordValidationStreamRes) GetProof() *Proof {
|
||||||
|
if x != nil {
|
||||||
|
return x.Proof
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_record_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
const file_record_proto_rawDesc = "" +
|
||||||
|
"\n" +
|
||||||
|
"\frecord.proto\x12\x06record\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\fcommon.proto\"\xdf\x01\n" +
|
||||||
|
"\n" +
|
||||||
|
"RecordData\x12\x0e\n" +
|
||||||
|
"\x02id\x18\x01 \x01(\tR\x02id\x12\x1b\n" +
|
||||||
|
"\tdo_prefix\x18\x02 \x01(\tR\bdoPrefix\x12\x1f\n" +
|
||||||
|
"\vproducer_id\x18\x03 \x01(\tR\n" +
|
||||||
|
"producerId\x128\n" +
|
||||||
|
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1a\n" +
|
||||||
|
"\boperator\x18\x05 \x01(\tR\boperator\x12\x14\n" +
|
||||||
|
"\x05extra\x18\x06 \x01(\fR\x05extra\x12\x17\n" +
|
||||||
|
"\arc_type\x18\a \x01(\tR\x06rcType\"\x99\x01\n" +
|
||||||
|
"\rListRecordReq\x12\x1b\n" +
|
||||||
|
"\tpage_size\x18\x01 \x01(\x04R\bpageSize\x125\n" +
|
||||||
|
"\bpre_time\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\apreTime\x12\x1b\n" +
|
||||||
|
"\tdo_prefix\x18\x03 \x01(\tR\bdoPrefix\x12\x17\n" +
|
||||||
|
"\arc_type\x18\x04 \x01(\tR\x06rcType\"M\n" +
|
||||||
|
"\rListRecordRes\x12\x14\n" +
|
||||||
|
"\x05count\x18\x01 \x01(\x03R\x05count\x12&\n" +
|
||||||
|
"\x04data\x18\x02 \x03(\v2\x12.record.RecordDataR\x04data\"\xa2\x01\n" +
|
||||||
|
"\x13RecordValidationReq\x128\n" +
|
||||||
|
"\ttimestamp\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1b\n" +
|
||||||
|
"\trecord_id\x18\x02 \x01(\tR\brecordId\x12\x1b\n" +
|
||||||
|
"\tdo_prefix\x18\x03 \x01(\tR\bdoPrefix\x12\x17\n" +
|
||||||
|
"\arc_type\x18\x04 \x01(\tR\x06rcType\"\xae\x01\n" +
|
||||||
|
"\x19RecordValidationStreamRes\x12\x12\n" +
|
||||||
|
"\x04code\x18\x01 \x01(\x05R\x04code\x12\x10\n" +
|
||||||
|
"\x03msg\x18\x02 \x01(\tR\x03msg\x12\x1a\n" +
|
||||||
|
"\bprogress\x18\x03 \x01(\tR\bprogress\x12*\n" +
|
||||||
|
"\x06result\x18\x04 \x01(\v2\x12.record.RecordDataR\x06result\x12#\n" +
|
||||||
|
"\x05proof\x18\x05 \x01(\v2\r.common.ProofR\x05proof2\xaa\x01\n" +
|
||||||
|
"\x17RecordValidationService\x12;\n" +
|
||||||
|
"\vListRecords\x12\x15.record.ListRecordReq\x1a\x15.record.ListRecordRes\x12R\n" +
|
||||||
|
"\x0eValidateRecord\x12\x1b.record.RecordValidationReq\x1a!.record.RecordValidationStreamRes0\x01B4Z2go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pbb\x06proto3"
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_record_proto_rawDescOnce sync.Once
|
||||||
|
file_record_proto_rawDescData []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_record_proto_rawDescGZIP() []byte {
|
||||||
|
file_record_proto_rawDescOnce.Do(func() {
|
||||||
|
file_record_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_record_proto_rawDesc), len(file_record_proto_rawDesc)))
|
||||||
|
})
|
||||||
|
return file_record_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_record_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||||
|
var file_record_proto_goTypes = []any{
|
||||||
|
(*RecordData)(nil), // 0: record.RecordData
|
||||||
|
(*ListRecordReq)(nil), // 1: record.ListRecordReq
|
||||||
|
(*ListRecordRes)(nil), // 2: record.ListRecordRes
|
||||||
|
(*RecordValidationReq)(nil), // 3: record.RecordValidationReq
|
||||||
|
(*RecordValidationStreamRes)(nil), // 4: record.RecordValidationStreamRes
|
||||||
|
(*timestamppb.Timestamp)(nil), // 5: google.protobuf.Timestamp
|
||||||
|
(*Proof)(nil), // 6: common.Proof
|
||||||
|
}
|
||||||
|
var file_record_proto_depIdxs = []int32{
|
||||||
|
5, // 0: record.RecordData.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
|
5, // 1: record.ListRecordReq.pre_time:type_name -> google.protobuf.Timestamp
|
||||||
|
0, // 2: record.ListRecordRes.data:type_name -> record.RecordData
|
||||||
|
5, // 3: record.RecordValidationReq.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
|
0, // 4: record.RecordValidationStreamRes.result:type_name -> record.RecordData
|
||||||
|
6, // 5: record.RecordValidationStreamRes.proof:type_name -> common.Proof
|
||||||
|
1, // 6: record.RecordValidationService.ListRecords:input_type -> record.ListRecordReq
|
||||||
|
3, // 7: record.RecordValidationService.ValidateRecord:input_type -> record.RecordValidationReq
|
||||||
|
2, // 8: record.RecordValidationService.ListRecords:output_type -> record.ListRecordRes
|
||||||
|
4, // 9: record.RecordValidationService.ValidateRecord:output_type -> record.RecordValidationStreamRes
|
||||||
|
8, // [8:10] is the sub-list for method output_type
|
||||||
|
6, // [6:8] is the sub-list for method input_type
|
||||||
|
6, // [6:6] is the sub-list for extension type_name
|
||||||
|
6, // [6:6] is the sub-list for extension extendee
|
||||||
|
0, // [0:6] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_record_proto_init() }
|
||||||
|
func file_record_proto_init() {
|
||||||
|
if File_record_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file_common_proto_init()
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_record_proto_rawDesc), len(file_record_proto_rawDesc)),
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 5,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 1,
|
||||||
|
},
|
||||||
|
GoTypes: file_record_proto_goTypes,
|
||||||
|
DependencyIndexes: file_record_proto_depIdxs,
|
||||||
|
MessageInfos: file_record_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_record_proto = out.File
|
||||||
|
file_record_proto_goTypes = nil
|
||||||
|
file_record_proto_depIdxs = nil
|
||||||
|
}
|
||||||
172
api/grpc/pb/record_grpc.pb.go
Normal file
172
api/grpc/pb/record_grpc.pb.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// - protoc-gen-go-grpc v1.5.1
|
||||||
|
// - protoc v3.21.12
|
||||||
|
// source: record.proto
|
||||||
|
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
// Requires gRPC-Go v1.64.0 or later.
|
||||||
|
const _ = grpc.SupportPackageIsVersion9
|
||||||
|
|
||||||
|
const (
|
||||||
|
RecordValidationService_ListRecords_FullMethodName = "/record.RecordValidationService/ListRecords"
|
||||||
|
RecordValidationService_ValidateRecord_FullMethodName = "/record.RecordValidationService/ValidateRecord"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecordValidationServiceClient is the client API for RecordValidationService service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
|
//
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
type RecordValidationServiceClient interface {
|
||||||
|
// 分页查询记录列表
|
||||||
|
ListRecords(ctx context.Context, in *ListRecordReq, opts ...grpc.CallOption) (*ListRecordRes, error)
|
||||||
|
// 单个记录验证,服务端流式返回验证进度与结果
|
||||||
|
ValidateRecord(ctx context.Context, in *RecordValidationReq, opts ...grpc.CallOption) (grpc.ServerStreamingClient[RecordValidationStreamRes], error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordValidationServiceClient struct {
|
||||||
|
cc grpc.ClientConnInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecordValidationServiceClient(cc grpc.ClientConnInterface) RecordValidationServiceClient {
|
||||||
|
return &recordValidationServiceClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordValidationServiceClient) ListRecords(ctx context.Context, in *ListRecordReq, opts ...grpc.CallOption) (*ListRecordRes, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(ListRecordRes)
|
||||||
|
err := c.cc.Invoke(ctx, RecordValidationService_ListRecords_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordValidationServiceClient) ValidateRecord(ctx context.Context, in *RecordValidationReq, opts ...grpc.CallOption) (grpc.ServerStreamingClient[RecordValidationStreamRes], error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
stream, err := c.cc.NewStream(ctx, &RecordValidationService_ServiceDesc.Streams[0], RecordValidationService_ValidateRecord_FullMethodName, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &grpc.GenericClientStream[RecordValidationReq, RecordValidationStreamRes]{ClientStream: stream}
|
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type RecordValidationService_ValidateRecordClient = grpc.ServerStreamingClient[RecordValidationStreamRes]
|
||||||
|
|
||||||
|
// RecordValidationServiceServer is the server API for RecordValidationService service.
|
||||||
|
// All implementations must embed UnimplementedRecordValidationServiceServer
|
||||||
|
// for forward compatibility.
|
||||||
|
//
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
type RecordValidationServiceServer interface {
|
||||||
|
// 分页查询记录列表
|
||||||
|
ListRecords(context.Context, *ListRecordReq) (*ListRecordRes, error)
|
||||||
|
// 单个记录验证,服务端流式返回验证进度与结果
|
||||||
|
ValidateRecord(*RecordValidationReq, grpc.ServerStreamingServer[RecordValidationStreamRes]) error
|
||||||
|
mustEmbedUnimplementedRecordValidationServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedRecordValidationServiceServer must be embedded to have
|
||||||
|
// forward compatible implementations.
|
||||||
|
//
|
||||||
|
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||||
|
// pointer dereference when methods are called.
|
||||||
|
type UnimplementedRecordValidationServiceServer struct{}
|
||||||
|
|
||||||
|
func (UnimplementedRecordValidationServiceServer) ListRecords(context.Context, *ListRecordReq) (*ListRecordRes, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method ListRecords not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedRecordValidationServiceServer) ValidateRecord(*RecordValidationReq, grpc.ServerStreamingServer[RecordValidationStreamRes]) error {
|
||||||
|
return status.Errorf(codes.Unimplemented, "method ValidateRecord not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedRecordValidationServiceServer) mustEmbedUnimplementedRecordValidationServiceServer() {
|
||||||
|
}
|
||||||
|
func (UnimplementedRecordValidationServiceServer) testEmbeddedByValue() {}
|
||||||
|
|
||||||
|
// UnsafeRecordValidationServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
|
// Use of this interface is not recommended, as added methods to RecordValidationServiceServer will
|
||||||
|
// result in compilation errors.
|
||||||
|
type UnsafeRecordValidationServiceServer interface {
|
||||||
|
mustEmbedUnimplementedRecordValidationServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterRecordValidationServiceServer(s grpc.ServiceRegistrar, srv RecordValidationServiceServer) {
|
||||||
|
// If the following call pancis, it indicates UnimplementedRecordValidationServiceServer was
|
||||||
|
// embedded by pointer and is nil. This will cause panics if an
|
||||||
|
// unimplemented method is ever invoked, so we test this at initialization
|
||||||
|
// time to prevent it from happening at runtime later due to I/O.
|
||||||
|
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||||
|
t.testEmbeddedByValue()
|
||||||
|
}
|
||||||
|
s.RegisterService(&RecordValidationService_ServiceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _RecordValidationService_ListRecords_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(ListRecordReq)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(RecordValidationServiceServer).ListRecords(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: RecordValidationService_ListRecords_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(RecordValidationServiceServer).ListRecords(ctx, req.(*ListRecordReq))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _RecordValidationService_ValidateRecord_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
m := new(RecordValidationReq)
|
||||||
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.(RecordValidationServiceServer).ValidateRecord(m, &grpc.GenericServerStream[RecordValidationReq, RecordValidationStreamRes]{ServerStream: stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||||
|
type RecordValidationService_ValidateRecordServer = grpc.ServerStreamingServer[RecordValidationStreamRes]
|
||||||
|
|
||||||
|
// RecordValidationService_ServiceDesc is the grpc.ServiceDesc for RecordValidationService service.
|
||||||
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
|
// and not to be introspected or modified (even as a copy)
|
||||||
|
var RecordValidationService_ServiceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "record.RecordValidationService",
|
||||||
|
HandlerType: (*RecordValidationServiceServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{
|
||||||
|
{
|
||||||
|
MethodName: "ListRecords",
|
||||||
|
Handler: _RecordValidationService_ListRecords_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{
|
||||||
|
{
|
||||||
|
StreamName: "ValidateRecord",
|
||||||
|
Handler: _RecordValidationService_ValidateRecord_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: "record.proto",
|
||||||
|
}
|
||||||
67
api/grpc/record.proto
Normal file
67
api/grpc/record.proto
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package record;
|
||||||
|
|
||||||
|
option go_package = "go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb;pb";
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
import "common.proto";
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 公共数据结构 ========================
|
||||||
|
message RecordData {
|
||||||
|
// 记录核心信息
|
||||||
|
string id = 1; // 记录唯一标识符(必填)
|
||||||
|
string do_prefix = 2; // 数据前缀标识符
|
||||||
|
string producer_id = 3; // 生产者ID
|
||||||
|
google.protobuf.Timestamp timestamp = 4;// 记录时间戳
|
||||||
|
string operator = 5; // 操作执行者标识
|
||||||
|
bytes extra = 6; // 额外数据字段
|
||||||
|
string rc_type = 7; // 记录类型
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 列表查询请求 & 返回 ========================
|
||||||
|
message ListRecordReq {
|
||||||
|
// 分页条件
|
||||||
|
uint64 page_size = 1; // 页面大小
|
||||||
|
google.protobuf.Timestamp pre_time = 2; // 上一页最后一个时间
|
||||||
|
|
||||||
|
// 可选过滤条件
|
||||||
|
string do_prefix = 3; // 数据前缀过滤
|
||||||
|
string rc_type = 4; // 记录类型过滤
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListRecordRes {
|
||||||
|
int64 count = 1; // 数据总量
|
||||||
|
repeated RecordData data = 2; // 记录数据列表
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== 记录验证请求 & 流式响应 ========================
|
||||||
|
message RecordValidationReq {
|
||||||
|
google.protobuf.Timestamp timestamp = 1;// 记录时间戳
|
||||||
|
string record_id = 2; // 要验证的记录ID
|
||||||
|
string do_prefix = 3; // 数据前缀(可选)
|
||||||
|
string rc_type = 4; // 记录类型
|
||||||
|
}
|
||||||
|
|
||||||
|
message RecordValidationStreamRes {
|
||||||
|
int32 code = 1; // 状态码(100处理中,200完成,400客户端错误,500服务器错误)
|
||||||
|
string msg = 2; // 消息描述
|
||||||
|
string progress = 3; // 验证进度(如 "30%", "验证哈希完成")
|
||||||
|
|
||||||
|
// 验证结果详情(仅在完成时返回)
|
||||||
|
RecordData result = 4;
|
||||||
|
common.Proof proof = 5; // 取证证明(仅在完成时返回)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ======================== gRPC 服务定义 ========================
|
||||||
|
service RecordValidationService {
|
||||||
|
// 分页查询记录列表
|
||||||
|
rpc ListRecords (ListRecordReq) returns (ListRecordRes);
|
||||||
|
|
||||||
|
// 单个记录验证,服务端流式返回验证进度与结果
|
||||||
|
rpc ValidateRecord (RecordValidationReq) returns (stream RecordValidationStreamRes);
|
||||||
|
}
|
||||||
156
api/highclient/client.go
Normal file
156
api/highclient/client.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package highclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
publisher message.Publisher
|
||||||
|
logger logger.Logger
|
||||||
|
envelopeConfig model.EnvelopeConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 创建HighClient,使用Envelope序列化方式.
|
||||||
|
// publisher可以使用任意(包含forwarder)创建的publisher,但是我们所有的订阅者必须可以处理Envelope格式的消息.
|
||||||
|
// 参数:
|
||||||
|
// - publisher: 消息发布器
|
||||||
|
// - logger: 日志记录器
|
||||||
|
// - envelopeConfig: SM2密钥配置,用于签名和序列化
|
||||||
|
func NewClient(publisher message.Publisher, logger logger.Logger, envelopeConfig model.EnvelopeConfig) *Client {
|
||||||
|
return &Client{
|
||||||
|
publisher: publisher,
|
||||||
|
logger: logger,
|
||||||
|
envelopeConfig: envelopeConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetLow() message.Publisher {
|
||||||
|
return c.publisher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) OperationPublish(operation *model.Operation) error {
|
||||||
|
if operation == nil {
|
||||||
|
c.logger.Error("operation publish failed: operation is nil")
|
||||||
|
return errors.New("operation cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debug("publishing operation",
|
||||||
|
"opID", operation.OpID,
|
||||||
|
"opType", operation.OpType,
|
||||||
|
"doPrefix", operation.DoPrefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
err := publish(operation, adapter.OperationTopic, c.publisher, c.envelopeConfig, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("operation publish failed",
|
||||||
|
"opID", operation.OpID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("operation published successfully",
|
||||||
|
"opID", operation.OpID,
|
||||||
|
"opType", operation.OpType,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) RecordPublish(record *model.Record) error {
|
||||||
|
if record == nil {
|
||||||
|
c.logger.Error("record publish failed: record is nil")
|
||||||
|
return errors.New("record cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debug("publishing record",
|
||||||
|
"recordID", record.ID,
|
||||||
|
"rcType", record.RCType,
|
||||||
|
"doPrefix", record.DoPrefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
err := publish(record, adapter.RecordTopic, c.publisher, c.envelopeConfig, c.logger)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("record publish failed",
|
||||||
|
"recordID", record.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("record published successfully",
|
||||||
|
"recordID", record.ID,
|
||||||
|
"rcType", record.RCType,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
c.logger.Info("closing high client")
|
||||||
|
err := c.publisher.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to close publisher", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Info("high client closed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// publish 通用的发布函数,支持任何实现了 Trustlog 接口的类型。
|
||||||
|
// 使用 Envelope 格式序列化并发布到指定 topic。
|
||||||
|
func publish(
|
||||||
|
data model.Trustlog,
|
||||||
|
topic string,
|
||||||
|
publisher message.Publisher,
|
||||||
|
config model.EnvelopeConfig,
|
||||||
|
logger logger.Logger,
|
||||||
|
) error {
|
||||||
|
messageKey := data.Key()
|
||||||
|
|
||||||
|
logger.Debug("starting envelope serialization",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"topic", topic,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 使用 Envelope 序列化(MarshalTrustlog 会自动提取 producerID)
|
||||||
|
envelopeData, err := model.MarshalTrustlog(data, config)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("envelope serialization failed",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("failed to marshal envelope: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("envelope serialized successfully",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"envelopeSize", len(envelopeData),
|
||||||
|
)
|
||||||
|
|
||||||
|
msg := message.NewMessage(messageKey, envelopeData)
|
||||||
|
logger.Debug("publishing message to topic",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"topic", topic,
|
||||||
|
)
|
||||||
|
|
||||||
|
if publishErr := publisher.Publish(topic, msg); publishErr != nil {
|
||||||
|
logger.Error("failed to publish to topic",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"topic", topic,
|
||||||
|
"error", publishErr,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("failed to publish message to topic %s: %w", topic, publishErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("message published to topic successfully",
|
||||||
|
"messageKey", messageKey,
|
||||||
|
"topic", topic,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
536
api/highclient/client_test.go
Normal file
536
api/highclient/client_test.go
Normal file
@@ -0,0 +1,536 @@
|
|||||||
|
package highclient_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill/message"
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/highclient"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPublisher 模拟 message.Publisher.
|
||||||
|
type MockPublisher struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockPublisher) Publish(topic string, messages ...*message.Message) error {
|
||||||
|
args := m.Called(topic, messages)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockPublisher) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestKeys 生成测试用的SM2密钥对(DER格式).
|
||||||
|
func generateTestKeys(t testing.TB) ([]byte, []byte) {
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
if err != nil {
|
||||||
|
if t != nil {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 私钥:DER编码
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
if err != nil {
|
||||||
|
if t != nil {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公钥:DER编码
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
if err != nil {
|
||||||
|
if t != nil {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKeyDER, publicKeyDER
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClient(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
require.NotNil(t, client)
|
||||||
|
assert.Equal(t, mockPublisher, client.GetLow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetLow(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
lowLevelPublisher := client.GetLow()
|
||||||
|
assert.Equal(t, mockPublisher, lowLevelPublisher)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_OperationPublish(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
operation *model.Operation
|
||||||
|
setupMock func(*MockPublisher)
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "成功发布Operation",
|
||||||
|
operation: createTestOperation(t),
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "发布失败",
|
||||||
|
operation: createTestOperation(t),
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Return(errors.New("publish failed")).
|
||||||
|
Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "publish failed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil Operation应该失败",
|
||||||
|
operation: nil,
|
||||||
|
setupMock: func(_ *MockPublisher) {
|
||||||
|
// nil operation不会调用Publish,因为会在之前失败
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
tt.setupMock(mockPublisher)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
err := client.OperationPublish(tt.operation)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_RecordPublish(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
record *model.Record
|
||||||
|
setupMock func(*MockPublisher)
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "成功发布Record",
|
||||||
|
record: createTestRecord(t),
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "发布失败",
|
||||||
|
record: createTestRecord(t),
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Return(errors.New("record publish failed")).
|
||||||
|
Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "record publish failed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil Record应该失败",
|
||||||
|
record: nil,
|
||||||
|
setupMock: func(_ *MockPublisher) {
|
||||||
|
// nil record不会调用Publish,因为会在之前失败
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
tt.setupMock(mockPublisher)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
err := client.RecordPublish(tt.record)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Close(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupMock func(*MockPublisher)
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "成功关闭",
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Close").Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "关闭失败",
|
||||||
|
setupMock: func(mp *MockPublisher) {
|
||||||
|
mp.On("Close").Return(errors.New("close failed")).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "close failed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
tt.setupMock(mockPublisher)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
err := client.Close()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_MessageContent(t *testing.T) {
|
||||||
|
// 测试发布的消息内容是否正确
|
||||||
|
t.Run("Operation消息内容验证", func(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
operation := createTestOperation(t)
|
||||||
|
|
||||||
|
// 捕获发布的消息
|
||||||
|
var capturedMessages []*message.Message
|
||||||
|
mockPublisher.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Run(func(args mock.Arguments) {
|
||||||
|
messages, ok := args.Get(1).([]*message.Message)
|
||||||
|
if ok {
|
||||||
|
capturedMessages = messages
|
||||||
|
}
|
||||||
|
}).Return(nil).Once()
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
err := client.OperationPublish(operation)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证消息内容
|
||||||
|
require.Len(t, capturedMessages, 1)
|
||||||
|
msg := capturedMessages[0]
|
||||||
|
assert.Equal(t, operation.Key(), msg.UUID)
|
||||||
|
assert.NotEmpty(t, msg.Payload)
|
||||||
|
|
||||||
|
// 验证是Envelope格式,可以反序列化
|
||||||
|
unmarshaledOp, err := model.UnmarshalOperation(msg.Payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, operation.OpID, unmarshaledOp.OpID)
|
||||||
|
|
||||||
|
// 验证签名
|
||||||
|
verifyConfig := model.NewSM2VerifyConfig(publicKey)
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(msg.Payload, verifyConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, verifiedEnv)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Record消息内容验证", func(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
record := createTestRecord(t)
|
||||||
|
|
||||||
|
// 捕获发布的消息
|
||||||
|
var capturedMessages []*message.Message
|
||||||
|
mockPublisher.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Run(func(args mock.Arguments) {
|
||||||
|
messages, ok := args.Get(1).([]*message.Message)
|
||||||
|
if ok {
|
||||||
|
capturedMessages = messages
|
||||||
|
}
|
||||||
|
}).Return(nil).Once()
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
err := client.RecordPublish(record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证消息内容
|
||||||
|
require.Len(t, capturedMessages, 1)
|
||||||
|
msg := capturedMessages[0]
|
||||||
|
assert.Equal(t, record.Key(), msg.UUID)
|
||||||
|
assert.NotEmpty(t, msg.Payload)
|
||||||
|
|
||||||
|
// 验证是Envelope格式,可以反序列化
|
||||||
|
unmarshaledRecord, err := model.UnmarshalRecord(msg.Payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, record.ID, unmarshaledRecord.ID)
|
||||||
|
|
||||||
|
// 验证签名
|
||||||
|
verifyConfig := model.NewSM2VerifyConfig(publicKey)
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(msg.Payload, verifyConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, verifiedEnv)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_ConcurrentPublish(t *testing.T) {
|
||||||
|
// 测试并发发布
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
|
||||||
|
// 设置期望的调用次数
|
||||||
|
publishCount := 100
|
||||||
|
mockPublisher.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Return(nil).Times(publishCount)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
// 并发发布
|
||||||
|
errChan := make(chan error, publishCount)
|
||||||
|
for i := range publishCount {
|
||||||
|
go func(id int) {
|
||||||
|
//nolint:testifylint // 在goroutine中创建测试数据,使用panic处理错误
|
||||||
|
operation := createTestOperationWithID(nil, fmt.Sprintf("concurrent-test-%d", id))
|
||||||
|
errChan <- client.OperationPublish(operation)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集结果
|
||||||
|
for range publishCount {
|
||||||
|
err := <-errChan
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_EdgeCases(t *testing.T) {
|
||||||
|
t.Run("发布大型Operation", func(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
mockPublisher.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).
|
||||||
|
Return(nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
// 创建包含大量数据的Operation
|
||||||
|
operation := createTestOperation(t)
|
||||||
|
operation.OpActor = string(make([]byte, 1000)) // 1KB数据
|
||||||
|
|
||||||
|
err := client.OperationPublish(operation)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("发布大型Record", func(t *testing.T) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
mockPublisher.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).Return(nil).Once()
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
// 创建包含大量数据的Record
|
||||||
|
record := createTestRecord(t)
|
||||||
|
record.WithExtra(make([]byte, 500)) // 500字节的额外数据
|
||||||
|
|
||||||
|
err := client.RecordPublish(record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Integration(t *testing.T) {
|
||||||
|
// 集成测试 - 测试完整的工作流程
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(t)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
|
||||||
|
// 设置期望:发布Operation -> 发布Record -> 关闭
|
||||||
|
mockPublisher.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).Return(nil).Once()
|
||||||
|
mockPublisher.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).Return(nil).Once()
|
||||||
|
mockPublisher.On("Close").Return(nil).Once()
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
|
||||||
|
// 发布Operation
|
||||||
|
operation := createTestOperation(t)
|
||||||
|
err := client.OperationPublish(operation)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 发布Record
|
||||||
|
record := createTestRecord(t)
|
||||||
|
err = client.RecordPublish(record)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 关闭客户端
|
||||||
|
err = client.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockPublisher.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 性能基准测试.
|
||||||
|
func BenchmarkClient_OperationPublish(b *testing.B) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(b)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
mockPublisher.On("Publish", adapter.OperationTopic, mock.AnythingOfType("[]*message.Message")).Return(nil)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
operation := createTestOperation(b)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
err := client.OperationPublish(operation)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkClient_RecordPublish(b *testing.B) {
|
||||||
|
mockPublisher := &MockPublisher{}
|
||||||
|
testLogger := logger.NewLogger(logr.Discard())
|
||||||
|
privateKey, publicKey := generateTestKeys(b)
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
mockPublisher.On("Publish", adapter.RecordTopic, mock.AnythingOfType("[]*message.Message")).Return(nil)
|
||||||
|
|
||||||
|
client := highclient.NewClient(mockPublisher, testLogger, config)
|
||||||
|
record := createTestRecord(b)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
err := client.RecordPublish(record)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试辅助函数.
|
||||||
|
func createTestOperation(t testing.TB) *model.Operation {
|
||||||
|
return createTestOperationWithID(t, "test-operation-001")
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestOperationWithID(t testing.TB, id string) *model.Operation {
|
||||||
|
// 在并发测试中,t可能为nil,这是正常的
|
||||||
|
errorHandler := func(err error) {
|
||||||
|
if t != nil {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
operation, err := model.NewFullOperation(
|
||||||
|
model.OpSourceDOIP,
|
||||||
|
model.OpTypeRetrieve,
|
||||||
|
"test-prefix",
|
||||||
|
"test-repo",
|
||||||
|
"test-prefix/test-repo/test-object",
|
||||||
|
"test-producer-id",
|
||||||
|
"test-actor",
|
||||||
|
"test request body",
|
||||||
|
"test response body",
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
errorHandler(err)
|
||||||
|
operation.OpID = id // 设置自定义ID
|
||||||
|
return operation
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestRecord(t testing.TB) *model.Record {
|
||||||
|
record, err := model.NewFullRecord(
|
||||||
|
"test-prefix",
|
||||||
|
"test-producer-id",
|
||||||
|
time.Now(),
|
||||||
|
"test-operator",
|
||||||
|
[]byte("test extra data"),
|
||||||
|
"test-type",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return record
|
||||||
|
}
|
||||||
183
api/logger/adapter.go
Normal file
183
api/logger/adapter.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logger 定义项目的日志接口,基于logr但提供更清晰的抽象.
|
||||||
|
// 支持结构化日志和上下文感知的日志记录.
|
||||||
|
type Logger interface {
|
||||||
|
// Context-aware structured logging methods (推荐使用)
|
||||||
|
DebugContext(ctx context.Context, msg string, args ...any)
|
||||||
|
InfoContext(ctx context.Context, msg string, args ...any)
|
||||||
|
WarnContext(ctx context.Context, msg string, args ...any)
|
||||||
|
ErrorContext(ctx context.Context, msg string, args ...any)
|
||||||
|
|
||||||
|
// Non-context structured logging methods (用于适配器和内部组件)
|
||||||
|
Debug(msg string, args ...any)
|
||||||
|
Info(msg string, args ...any)
|
||||||
|
Warn(msg string, args ...any)
|
||||||
|
Error(msg string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogger 创建一个基于logr的Logger实现.
|
||||||
|
func NewLogger(logger logr.Logger) Logger {
|
||||||
|
return &logrAdapter{
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultLogger 创建一个默认的Logger实现.
|
||||||
|
// 注意:logr需要显式提供一个LogSink实现,这里返回一个discard logger.
|
||||||
|
func NewDefaultLogger() Logger {
|
||||||
|
return &logrAdapter{
|
||||||
|
logger: logr.Discard(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NopLogger 空操作日志器实现,所有日志方法都不执行任何操作.
|
||||||
|
// 适用于不需要日志输出的场景,如测试或性能敏感的场景.
|
||||||
|
type NopLogger struct{}
|
||||||
|
|
||||||
|
// NewNopLogger 创建新的空操作日志器.
|
||||||
|
func NewNopLogger() *NopLogger {
|
||||||
|
return &NopLogger{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NopLogger) DebugContext(_ context.Context, _ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) InfoContext(_ context.Context, _ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) WarnContext(_ context.Context, _ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) ErrorContext(_ context.Context, _ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) Debug(_ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) Info(_ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) Warn(_ string, _ ...any) {}
|
||||||
|
func (n *NopLogger) Error(_ string, _ ...any) {}
|
||||||
|
|
||||||
|
// 全局日志器相关变量.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // 全局日志器是必要的,用于提供便捷的日志访问接口
|
||||||
|
var (
|
||||||
|
globalLogger Logger = NewNopLogger() // 默认使用 NopLogger
|
||||||
|
globalLoggerLock sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetGlobalLogger 设置全局日志器.
|
||||||
|
// 线程安全,可以在程序启动时调用以设置全局日志器.
|
||||||
|
func SetGlobalLogger(logger Logger) {
|
||||||
|
if logger == nil {
|
||||||
|
logger = NewNopLogger()
|
||||||
|
}
|
||||||
|
globalLoggerLock.Lock()
|
||||||
|
defer globalLoggerLock.Unlock()
|
||||||
|
globalLogger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalLogger 获取全局日志器.
|
||||||
|
// 线程安全,返回当前设置的全局日志器,如果未设置则返回 NopLogger.
|
||||||
|
func GetGlobalLogger() Logger {
|
||||||
|
globalLoggerLock.RLock()
|
||||||
|
defer globalLoggerLock.RUnlock()
|
||||||
|
return globalLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// logrAdapter 是Logger接口的logr实现.
|
||||||
|
type logrAdapter struct {
|
||||||
|
logger logr.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertArgs 将args转换为logr的key-value对格式.
|
||||||
|
// logr要求key-value成对出现,如果args是奇数个,最后一个会作为单独的value.
|
||||||
|
func convertArgs(args ...any) []any {
|
||||||
|
// 如果args已经是成对的key-value格式,直接返回
|
||||||
|
if len(args)%2 == 0 {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
// 如果不是成对的,可能需要特殊处理
|
||||||
|
// 这里我们假设调用者传入的是key-value对
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) DebugContext(ctx context.Context, msg string, args ...any) {
|
||||||
|
_ = ctx // 保持接口兼容性,logr目前不支持context
|
||||||
|
l.logger.V(1).Info(msg, convertArgs(args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) InfoContext(ctx context.Context, msg string, args ...any) {
|
||||||
|
_ = ctx // 保持接口兼容性,logr目前不支持context
|
||||||
|
l.logger.Info(msg, convertArgs(args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) WarnContext(ctx context.Context, msg string, args ...any) {
|
||||||
|
_ = ctx // 保持接口兼容性,logr目前不支持context
|
||||||
|
// logr没有Warn级别,使用Info但标记为warning
|
||||||
|
kv := convertArgs(args...)
|
||||||
|
kv = append(kv, "level", "warning")
|
||||||
|
l.logger.Info(msg, kv...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) ErrorContext(ctx context.Context, msg string, args ...any) {
|
||||||
|
_ = ctx // 保持接口兼容性,logr目前不支持context
|
||||||
|
// 尝试从args中提取error
|
||||||
|
var err error
|
||||||
|
kvArgs := make([]any, 0, len(args))
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
if i+1 < len(args) && args[i] == "error" {
|
||||||
|
if e, ok := args[i+1].(error); ok {
|
||||||
|
err = e
|
||||||
|
i++ // 跳过error值
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kvArgs = append(kvArgs, args[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error(err, msg, convertArgs(kvArgs...)...)
|
||||||
|
} else {
|
||||||
|
// 如果没有error,使用Info但标记为error级别
|
||||||
|
kvArgs = append(kvArgs, "level", "error")
|
||||||
|
l.logger.Info(msg, convertArgs(kvArgs...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) Debug(msg string, args ...any) {
|
||||||
|
l.logger.V(1).Info(msg, convertArgs(args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) Info(msg string, args ...any) {
|
||||||
|
l.logger.Info(msg, convertArgs(args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) Warn(msg string, args ...any) {
|
||||||
|
// logr没有Warn级别,使用Info但标记为warning
|
||||||
|
kv := convertArgs(args...)
|
||||||
|
kv = append(kv, "level", "warning")
|
||||||
|
l.logger.Info(msg, kv...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logrAdapter) Error(msg string, args ...any) {
|
||||||
|
// 尝试从args中提取error
|
||||||
|
var err error
|
||||||
|
kvArgs := make([]any, 0, len(args))
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
if i+1 < len(args) && args[i] == "error" {
|
||||||
|
if e, ok := args[i+1].(error); ok {
|
||||||
|
err = e
|
||||||
|
i++ // 跳过error值
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kvArgs = append(kvArgs, args[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error(err, msg, convertArgs(kvArgs...)...)
|
||||||
|
} else {
|
||||||
|
// 如果没有error,使用Info但标记为error级别
|
||||||
|
kvArgs = append(kvArgs, "level", "error")
|
||||||
|
l.logger.Info(msg, convertArgs(kvArgs...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
255
api/logger/adapter_test.go
Normal file
255
api/logger/adapter_test.go
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
package logger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
result := logger.NewLogger(l)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDefaultLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := logger.NewDefaultLogger()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewNopLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := logger.NewNopLogger()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopLogger_Methods(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
nop := logger.NewNopLogger()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// All methods should not panic
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
nop.DebugContext(ctx, "test")
|
||||||
|
nop.InfoContext(ctx, "test")
|
||||||
|
nop.WarnContext(ctx, "test")
|
||||||
|
nop.ErrorContext(ctx, "test")
|
||||||
|
nop.Debug("test")
|
||||||
|
nop.Info("test")
|
||||||
|
nop.Warn("test")
|
||||||
|
nop.Error("test")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_DebugContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.DebugContext(ctx, "debug message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_InfoContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.InfoContext(ctx, "info message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_WarnContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.WarnContext(ctx, "warn message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_ErrorContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with error",
|
||||||
|
args: []any{"error", errors.New("test error"), "key", "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without error",
|
||||||
|
args: []any{"key", "value"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.ErrorContext(ctx, "error message", tt.args...)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Debug(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Debug("debug message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Info(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Info("info message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Warn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Warn("warn message", "key", "value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Error(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := logr.Discard()
|
||||||
|
adapter := logger.NewLogger(l)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with error",
|
||||||
|
args: []any{"error", errors.New("test error"), "key", "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without error",
|
||||||
|
args: []any{"key", "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "odd number of args",
|
||||||
|
args: []any{"key", "value", "extra"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Error("error message", tt.args...)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGlobalLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := logger.GetGlobalLogger()
|
||||||
|
defer logger.SetGlobalLogger(original)
|
||||||
|
|
||||||
|
newLogger := logger.NewNopLogger()
|
||||||
|
logger.SetGlobalLogger(newLogger)
|
||||||
|
|
||||||
|
result := logger.GetGlobalLogger()
|
||||||
|
assert.Equal(t, newLogger, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGlobalLogger_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := logger.GetGlobalLogger()
|
||||||
|
defer logger.SetGlobalLogger(original)
|
||||||
|
|
||||||
|
logger.SetGlobalLogger(nil)
|
||||||
|
result := logger.GetGlobalLogger()
|
||||||
|
assert.NotNil(t, result) // Should be NopLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetGlobalLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := logger.GetGlobalLogger()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalLogger_ConcurrentAccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := logger.GetGlobalLogger()
|
||||||
|
defer logger.SetGlobalLogger(original)
|
||||||
|
|
||||||
|
// Test concurrent reads
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for range 10 {
|
||||||
|
go func() {
|
||||||
|
_ = logger.GetGlobalLogger()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test concurrent writes
|
||||||
|
newLogger := logger.NewNopLogger()
|
||||||
|
for range 5 {
|
||||||
|
go func() {
|
||||||
|
logger.SetGlobalLogger(newLogger)
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for range 5 {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
result := logger.GetGlobalLogger()
|
||||||
|
require.NotNil(t, result)
|
||||||
|
}
|
||||||
207
api/model/config_signer.go
Normal file
207
api/model/config_signer.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/crpt/go-crpt"
|
||||||
|
_ "github.com/crpt/go-crpt/ed25519" // 注册 Ed25519
|
||||||
|
_ "github.com/crpt/go-crpt/sm2" // 注册 SM2
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigSigner 基于配置的通用签名器
|
||||||
|
// 根据 CryptoConfig 自动使用对应的签名算法
|
||||||
|
type ConfigSigner struct {
|
||||||
|
privateKey []byte // 私钥(序列化格式)
|
||||||
|
publicKey []byte // 公钥(序列化格式)
|
||||||
|
config *CryptoConfig // 密码学配置
|
||||||
|
privKey crpt.PrivateKey // 解析后的私钥
|
||||||
|
pubKey crpt.PublicKey // 解析后的公钥
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigSigner 创建基于配置的签名器
|
||||||
|
// 如果 config 为 nil,则使用全局配置
|
||||||
|
func NewConfigSigner(privateKey, publicKey []byte, config *CryptoConfig) (*ConfigSigner, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = GetGlobalCryptoConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating ConfigSigner",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"privateKeyLength", len(privateKey),
|
||||||
|
"publicKeyLength", len(publicKey),
|
||||||
|
)
|
||||||
|
|
||||||
|
signer := &ConfigSigner{
|
||||||
|
privateKey: privateKey,
|
||||||
|
publicKey: publicKey,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 延迟解析密钥,只在需要时解析
|
||||||
|
// 这样可以避免初始化顺序问题
|
||||||
|
|
||||||
|
log.Debug("ConfigSigner created successfully",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultSigner 创建使用默认 SM2 算法的签名器
|
||||||
|
// 注意:总是使用 SM2,不受全局配置影响
|
||||||
|
func NewDefaultSigner(privateKey, publicKey []byte) (*ConfigSigner, error) {
|
||||||
|
sm2Config := &CryptoConfig{
|
||||||
|
SignatureAlgorithm: SM2Algorithm,
|
||||||
|
}
|
||||||
|
return NewConfigSigner(privateKey, publicKey, sm2Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign 对数据进行签名
|
||||||
|
func (s *ConfigSigner) Sign(data []byte) ([]byte, error) {
|
||||||
|
if len(s.privateKey) == 0 {
|
||||||
|
return nil, fmt.Errorf("private key is not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 根据算法类型使用对应的方法
|
||||||
|
switch s.config.SignatureAlgorithm {
|
||||||
|
case SM2Algorithm:
|
||||||
|
// SM2 使用现有的 ComputeSignature 函数(兼容 DER 格式)
|
||||||
|
signature, err := ComputeSignature(data, s.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign with SM2",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Signed successfully with SM2",
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 其他算法使用 crpt 通用接口
|
||||||
|
// 懒加载:解析私钥
|
||||||
|
if s.privKey == nil {
|
||||||
|
keyType, err := s.config.SignatureAlgorithm.toKeyType()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, err := crpt.PrivateKeyFromBytes(keyType, s.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse private key",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
|
}
|
||||||
|
s.privKey = privKey
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := crpt.SignMessage(s.privKey, data, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to sign: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Signed successfully with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 验证签名
|
||||||
|
func (s *ConfigSigner) Verify(data, signature []byte) (bool, error) {
|
||||||
|
if len(s.publicKey) == 0 {
|
||||||
|
return false, fmt.Errorf("public key is not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 根据算法类型使用对应的方法
|
||||||
|
switch s.config.SignatureAlgorithm {
|
||||||
|
case SM2Algorithm:
|
||||||
|
// SM2 使用现有的 VerifySignature 函数(兼容 DER 格式)
|
||||||
|
ok, err := VerifySignature(data, s.publicKey, signature)
|
||||||
|
if err != nil {
|
||||||
|
// VerifySignature 在验证失败时也返回错误,需要判断错误类型
|
||||||
|
// 如果是 "signature verification failed",则返回 (false, nil)
|
||||||
|
if ok == false {
|
||||||
|
// 验证失败(不是异常)
|
||||||
|
log.Warn("Verification failed with SM2")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
// 其他错误(如解析错误)
|
||||||
|
log.Error("Failed to verify with SM2", "error", err)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
log.Debug("Verified successfully with SM2")
|
||||||
|
return true, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 其他算法使用 crpt 通用接口
|
||||||
|
// 懒加载:解析公钥
|
||||||
|
if s.pubKey == nil {
|
||||||
|
keyType, err := s.config.SignatureAlgorithm.toKeyType()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey, err := crpt.PublicKeyFromBytes(keyType, s.publicKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse public key",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf("failed to parse public key: %w", err)
|
||||||
|
}
|
||||||
|
s.pubKey = pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := crpt.VerifyMessage(s.pubKey, data, crpt.Signature(signature), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to verify with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf("failed to verify: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debug("Verified successfully with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("Verification failed with ConfigSigner",
|
||||||
|
"algorithm", s.config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAlgorithm 获取签名器使用的算法
|
||||||
|
func (s *ConfigSigner) GetAlgorithm() SignatureAlgorithm {
|
||||||
|
return s.config.SignatureAlgorithm
|
||||||
|
}
|
||||||
158
api/model/config_signer_test.go
Normal file
158
api/model/config_signer_test.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/crpt/go-crpt/sm2" // 确保 SM2 已注册
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewConfigSigner_SM2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成 SM2 密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名器
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.SM2Algorithm,
|
||||||
|
}
|
||||||
|
signer, err := model.NewConfigSigner(privateKeyDER, publicKeyDER, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, signer)
|
||||||
|
assert.Equal(t, model.SM2Algorithm, signer.GetAlgorithm())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDefaultSigner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成 SM2 密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建默认签名器(应该使用 SM2)
|
||||||
|
signer, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, signer)
|
||||||
|
assert.Equal(t, model.SM2Algorithm, signer.GetAlgorithm())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSigner_SignAndVerify_SM2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名器
|
||||||
|
signer, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 签名
|
||||||
|
data := []byte("test data for ConfigSigner")
|
||||||
|
signature, err := signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// 验证
|
||||||
|
ok, err := signer.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// 验证错误数据
|
||||||
|
wrongData := []byte("wrong data")
|
||||||
|
ok, err = signer.Verify(wrongData, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSigner_SignAndVerify_Ed25519(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成 Ed25519 密钥对
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||||
|
}
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名器
|
||||||
|
signer, err := model.NewConfigSigner(privateKeyDER, publicKeyDER, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 签名
|
||||||
|
data := []byte("test data for Ed25519")
|
||||||
|
signature, err := signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// 验证
|
||||||
|
ok, err := signer.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSigner_CompatibleWithSM2Signer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 使用 ConfigSigner 签名
|
||||||
|
configSigner, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data := []byte("test data")
|
||||||
|
signature1, err := configSigner.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 使用 SM2Signer 验证
|
||||||
|
sm2Signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||||
|
ok1, err := sm2Signer.Verify(data, signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok1, "SM2Signer should verify ConfigSigner's signature")
|
||||||
|
|
||||||
|
// 使用 SM2Signer 签名
|
||||||
|
signature2, err := sm2Signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 使用 ConfigSigner 验证
|
||||||
|
ok2, err := configSigner.Verify(data, signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok2, "ConfigSigner should verify SM2Signer's signature")
|
||||||
|
}
|
||||||
200
api/model/converter.go
Normal file
200
api/model/converter.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromProtobuf 将protobuf的OperationData转换为model.Operation.
|
||||||
|
func FromProtobuf(pbOp *pb.OperationData) (*Operation, error) {
|
||||||
|
if pbOp == nil {
|
||||||
|
return nil, errors.New("protobuf operation data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换时间戳
|
||||||
|
if pbOp.GetTimestamp() == nil {
|
||||||
|
return nil, errors.New("timestamp is required")
|
||||||
|
}
|
||||||
|
timestamp := pbOp.GetTimestamp().AsTime()
|
||||||
|
|
||||||
|
// 构建Operation
|
||||||
|
operation := &Operation{
|
||||||
|
OpID: pbOp.GetOpId(),
|
||||||
|
Timestamp: timestamp,
|
||||||
|
OpSource: Source(pbOp.GetOpSource()),
|
||||||
|
OpType: Type(pbOp.GetOpType()),
|
||||||
|
DoPrefix: pbOp.GetDoPrefix(),
|
||||||
|
DoRepository: pbOp.GetDoRepository(),
|
||||||
|
Doid: pbOp.GetDoid(),
|
||||||
|
ProducerID: pbOp.GetProducerId(),
|
||||||
|
OpActor: pbOp.GetOpActor(),
|
||||||
|
// OpAlgorithm和OpMetaHash字段已移除,固定使用Sha256Simd,哈希值由Envelope的OriginalHash提供
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理可选的哈希字段
|
||||||
|
if reqHash := pbOp.GetRequestBodyHash(); reqHash != "" {
|
||||||
|
operation.RequestBodyHash = &reqHash
|
||||||
|
}
|
||||||
|
if respHash := pbOp.GetResponseBodyHash(); respHash != "" {
|
||||||
|
operation.ResponseBodyHash = &respHash
|
||||||
|
}
|
||||||
|
|
||||||
|
return operation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToProtobuf 将model.Operation转换为protobuf的OperationData.
|
||||||
|
func ToProtobuf(op *Operation) (*pb.OperationData, error) {
|
||||||
|
if op == nil {
|
||||||
|
return nil, errors.New("operation is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换时间戳
|
||||||
|
timestamp := timestamppb.New(op.Timestamp)
|
||||||
|
|
||||||
|
pbOp := &pb.OperationData{
|
||||||
|
OpId: op.OpID,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
OpSource: string(op.OpSource),
|
||||||
|
OpType: string(op.OpType),
|
||||||
|
DoPrefix: op.DoPrefix,
|
||||||
|
DoRepository: op.DoRepository,
|
||||||
|
Doid: op.Doid,
|
||||||
|
ProducerId: op.ProducerID,
|
||||||
|
OpActor: op.OpActor,
|
||||||
|
// OpAlgorithm、OpMetaHash和OpHash字段已移除,固定使用Sha256Simd,哈希值由Envelope的OriginalHash提供
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理可选的哈希字段
|
||||||
|
if op.RequestBodyHash != nil {
|
||||||
|
pbOp.RequestBodyHash = *op.RequestBodyHash
|
||||||
|
}
|
||||||
|
if op.ResponseBodyHash != nil {
|
||||||
|
pbOp.ResponseBodyHash = *op.ResponseBodyHash
|
||||||
|
}
|
||||||
|
|
||||||
|
return pbOp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromProtobufValidationResult 将protobuf的ValidationStreamRes转换为model.ValidationResult.
|
||||||
|
func FromProtobufValidationResult(pbRes *pb.ValidationStreamRes) (*ValidationResult, error) {
|
||||||
|
if pbRes == nil {
|
||||||
|
return nil, errors.New("protobuf validation result is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &ValidationResult{
|
||||||
|
Code: pbRes.GetCode(),
|
||||||
|
Msg: pbRes.GetMsg(),
|
||||||
|
Progress: pbRes.GetProgress(),
|
||||||
|
Proof: ProofFromProtobuf(pbRes.GetProof()), // 取证证明
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有操作数据,则转换
|
||||||
|
if pbRes.GetData() != nil {
|
||||||
|
op, err := FromProtobuf(pbRes.GetData())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to convert operation data: %w", err)
|
||||||
|
}
|
||||||
|
result.Data = op
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFromProtobuf 将protobuf的RecordData转换为model.Record.
|
||||||
|
func RecordFromProtobuf(pbRec *pb.RecordData) (*Record, error) {
|
||||||
|
if pbRec == nil {
|
||||||
|
return nil, errors.New("protobuf record data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建Record
|
||||||
|
record := &Record{
|
||||||
|
ID: pbRec.GetId(),
|
||||||
|
DoPrefix: pbRec.GetDoPrefix(),
|
||||||
|
ProducerID: pbRec.GetProducerId(),
|
||||||
|
Operator: pbRec.GetOperator(),
|
||||||
|
Extra: pbRec.GetExtra(),
|
||||||
|
RCType: pbRec.GetRcType(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换时间戳
|
||||||
|
if pbRec.GetTimestamp() != nil {
|
||||||
|
record.Timestamp = pbRec.GetTimestamp().AsTime()
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordToProtobuf 将model.Record转换为protobuf的RecordData.
|
||||||
|
func RecordToProtobuf(rec *Record) (*pb.RecordData, error) {
|
||||||
|
if rec == nil {
|
||||||
|
return nil, errors.New("record is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换时间戳
|
||||||
|
timestamp := timestamppb.New(rec.Timestamp)
|
||||||
|
|
||||||
|
pbRec := &pb.RecordData{
|
||||||
|
Id: rec.ID,
|
||||||
|
DoPrefix: rec.DoPrefix,
|
||||||
|
ProducerId: rec.ProducerID,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Operator: rec.Operator,
|
||||||
|
Extra: rec.Extra,
|
||||||
|
RcType: rec.RCType,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pbRec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordValidationResult 包装记录验证的流式响应结果.
|
||||||
|
type RecordValidationResult struct {
|
||||||
|
Code int32 // 状态码(100处理中,200完成,500失败)
|
||||||
|
Msg string // 消息描述
|
||||||
|
Progress string // 当前进度(比如 "50%")
|
||||||
|
Data *Record // 最终完成时返回的记录数据,过程中可为空
|
||||||
|
Proof *Proof // 取证证明(仅在完成时返回)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProcessing 判断是否正在处理中.
|
||||||
|
func (r *RecordValidationResult) IsProcessing() bool {
|
||||||
|
return r.Code == ValidationCodeProcessing
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCompleted 判断是否已完成.
|
||||||
|
func (r *RecordValidationResult) IsCompleted() bool {
|
||||||
|
return r.Code == ValidationCodeCompleted
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFailed 判断是否失败.
|
||||||
|
func (r *RecordValidationResult) IsFailed() bool {
|
||||||
|
return r.Code >= ValidationCodeFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFromProtobufValidationResult 将protobuf的RecordValidationStreamRes转换为model.RecordValidationResult.
|
||||||
|
func RecordFromProtobufValidationResult(pbRes *pb.RecordValidationStreamRes) (*RecordValidationResult, error) {
|
||||||
|
if pbRes == nil {
|
||||||
|
return nil, errors.New("protobuf record validation result is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &RecordValidationResult{
|
||||||
|
Code: pbRes.GetCode(),
|
||||||
|
Msg: pbRes.GetMsg(),
|
||||||
|
Progress: pbRes.GetProgress(),
|
||||||
|
Proof: ProofFromProtobuf(pbRes.GetProof()), // 取证证明
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有记录数据,则转换
|
||||||
|
if pbRes.GetResult() != nil {
|
||||||
|
rec, err := RecordFromProtobuf(pbRes.GetResult())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to convert record data: %w", err)
|
||||||
|
}
|
||||||
|
result.Data = rec
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
575
api/model/converter_test.go
Normal file
575
api/model/converter_test.go
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFromProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.FromProtobuf(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "protobuf operation data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobuf_NoTimestamp(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbOp := &pb.OperationData{}
|
||||||
|
result, err := model.FromProtobuf(pbOp)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "timestamp is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobuf_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbOp := &pb.OperationData{
|
||||||
|
OpId: "op-123",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
OpSource: "IRP",
|
||||||
|
OpType: "OC_CREATE_HANDLE",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobuf(pbOp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "op-123", result.OpID)
|
||||||
|
assert.Equal(t, now.Unix(), result.Timestamp.Unix())
|
||||||
|
assert.Equal(t, model.Source("IRP"), result.OpSource)
|
||||||
|
assert.Equal(t, model.Type("OC_CREATE_HANDLE"), result.OpType)
|
||||||
|
assert.Equal(t, "test", result.DoPrefix)
|
||||||
|
assert.Equal(t, "repo", result.DoRepository)
|
||||||
|
assert.Equal(t, "test/repo/123", result.Doid)
|
||||||
|
assert.Equal(t, "producer-1", result.ProducerID)
|
||||||
|
assert.Equal(t, "actor-1", result.OpActor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobuf_WithHashes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbOp := &pb.OperationData{
|
||||||
|
OpId: "op-123",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
OpSource: "DOIP",
|
||||||
|
OpType: "Create",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
RequestBodyHash: "req-hash",
|
||||||
|
ResponseBodyHash: "resp-hash",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobuf(pbOp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.NotNil(t, result.RequestBodyHash)
|
||||||
|
assert.Equal(t, "req-hash", *result.RequestBodyHash)
|
||||||
|
assert.NotNil(t, result.ResponseBodyHash)
|
||||||
|
assert.Equal(t, "resp-hash", *result.ResponseBodyHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobuf_EmptyHashes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbOp := &pb.OperationData{
|
||||||
|
OpId: "op-123",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
OpSource: "DOIP",
|
||||||
|
OpType: "Create",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
RequestBodyHash: "",
|
||||||
|
ResponseBodyHash: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobuf(pbOp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Nil(t, result.RequestBodyHash)
|
||||||
|
assert.Nil(t, result.ResponseBodyHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.ToProtobuf(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "operation is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtobuf_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: now,
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.ToProtobuf(op)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "op-123", result.GetOpId())
|
||||||
|
assert.Equal(t, now.Unix(), result.GetTimestamp().AsTime().Unix())
|
||||||
|
assert.Equal(t, "IRP", result.GetOpSource())
|
||||||
|
assert.Equal(t, "OC_CREATE_HANDLE", result.GetOpType())
|
||||||
|
assert.Equal(t, "test", result.GetDoPrefix())
|
||||||
|
assert.Equal(t, "repo", result.GetDoRepository())
|
||||||
|
assert.Equal(t, "test/repo/123", result.GetDoid())
|
||||||
|
assert.Equal(t, "producer-1", result.GetProducerId())
|
||||||
|
assert.Equal(t, "actor-1", result.GetOpActor())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtobuf_WithHashes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
reqHash := "req-hash"
|
||||||
|
respHash := "resp-hash"
|
||||||
|
now := time.Now()
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: now,
|
||||||
|
OpSource: model.OpSourceDOIP,
|
||||||
|
OpType: model.OpTypeCreate,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
RequestBodyHash: &reqHash,
|
||||||
|
ResponseBodyHash: &respHash,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.ToProtobuf(op)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "req-hash", result.GetRequestBodyHash())
|
||||||
|
assert.Equal(t, "resp-hash", result.GetResponseBodyHash())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtobuf_WithoutHashes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: now,
|
||||||
|
OpSource: model.OpSourceDOIP,
|
||||||
|
OpType: model.OpTypeCreate,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.ToProtobuf(op)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Empty(t, result.GetRequestBodyHash())
|
||||||
|
assert.Empty(t, result.GetResponseBodyHash())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobufValidationResult_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.FromProtobufValidationResult(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "protobuf validation result is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobufValidationResult_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.ValidationStreamRes{
|
||||||
|
Code: 100,
|
||||||
|
Msg: "Processing",
|
||||||
|
Progress: "50%",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(100), result.Code)
|
||||||
|
assert.Equal(t, "Processing", result.Msg)
|
||||||
|
assert.Equal(t, "50%", result.Progress)
|
||||||
|
assert.Nil(t, result.Data)
|
||||||
|
assert.Nil(t, result.Proof)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobufValidationResult_WithProof(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.ValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Proof: &pb.Proof{
|
||||||
|
Sign: "test-signature",
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.NotNil(t, result.Proof)
|
||||||
|
assert.Equal(t, "test-signature", result.Proof.Sign)
|
||||||
|
assert.Len(t, result.Proof.ColItems, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobufValidationResult_WithData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbRes := &pb.ValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Data: &pb.OperationData{
|
||||||
|
OpId: "op-123",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
OpSource: "IRP",
|
||||||
|
OpType: "OC_CREATE_HANDLE",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.NotNil(t, result.Data)
|
||||||
|
assert.Equal(t, "op-123", result.Data.OpID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromProtobufValidationResult_WithInvalidData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.ValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Data: &pb.OperationData{
|
||||||
|
// Missing timestamp
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.FromProtobufValidationResult(pbRes)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to convert operation data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobuf(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "protobuf record data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobuf_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbRec := &pb.RecordData{
|
||||||
|
Id: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra-data"),
|
||||||
|
RcType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobuf(pbRec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "rec-123", result.ID)
|
||||||
|
assert.Equal(t, "test", result.DoPrefix)
|
||||||
|
assert.Equal(t, "producer-1", result.ProducerID)
|
||||||
|
assert.Equal(t, now.Unix(), result.Timestamp.Unix())
|
||||||
|
assert.Equal(t, "operator-1", result.Operator)
|
||||||
|
assert.Equal(t, []byte("extra-data"), result.Extra)
|
||||||
|
assert.Equal(t, "log", result.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobuf_NoTimestamp(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRec := &pb.RecordData{
|
||||||
|
Id: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra-data"),
|
||||||
|
RcType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobuf(pbRec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "rec-123", result.ID)
|
||||||
|
assert.True(t, result.Timestamp.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordToProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.RecordToProtobuf(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "record is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordToProtobuf_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: now,
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra-data"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordToProtobuf(rec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, "rec-123", result.GetId())
|
||||||
|
assert.Equal(t, "test", result.GetDoPrefix())
|
||||||
|
assert.Equal(t, "producer-1", result.GetProducerId())
|
||||||
|
assert.Equal(t, now.Unix(), result.GetTimestamp().AsTime().Unix())
|
||||||
|
assert.Equal(t, "operator-1", result.GetOperator())
|
||||||
|
assert.Equal(t, []byte("extra-data"), result.GetExtra())
|
||||||
|
assert.Equal(t, "log", result.GetRcType())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobufValidationResult_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobufValidationResult(nil)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "protobuf record validation result is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobufValidationResult_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.RecordValidationStreamRes{
|
||||||
|
Code: 100,
|
||||||
|
Msg: "Processing",
|
||||||
|
Progress: "50%",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(100), result.Code)
|
||||||
|
assert.Equal(t, "Processing", result.Msg)
|
||||||
|
assert.Equal(t, "50%", result.Progress)
|
||||||
|
assert.Nil(t, result.Data)
|
||||||
|
assert.Nil(t, result.Proof)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobufValidationResult_WithProof(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.RecordValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Proof: &pb.Proof{
|
||||||
|
Sign: "test-signature",
|
||||||
|
RawItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.NotNil(t, result.Proof)
|
||||||
|
assert.Equal(t, "test-signature", result.Proof.Sign)
|
||||||
|
assert.Len(t, result.Proof.RawItems, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobufValidationResult_WithData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
pbRes := &pb.RecordValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Result: &pb.RecordData{
|
||||||
|
Id: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Timestamp: timestamppb.New(now),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra-data"),
|
||||||
|
RcType: "log",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.NotNil(t, result.Data)
|
||||||
|
assert.Equal(t, "rec-123", result.Data.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordFromProtobufValidationResult_WithInvalidData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbRes := &pb.RecordValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Result: &pb.RecordData{
|
||||||
|
// Missing required fields to trigger error
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||||
|
// This should succeed even with empty RecordData
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTrip_Operation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
original := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: now,
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to protobuf
|
||||||
|
pbOp, err := model.ToProtobuf(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, pbOp)
|
||||||
|
|
||||||
|
// Convert back to model
|
||||||
|
result, err := model.FromProtobuf(pbOp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
// Verify round trip
|
||||||
|
assert.Equal(t, original.OpID, result.OpID)
|
||||||
|
assert.Equal(t, original.OpSource, result.OpSource)
|
||||||
|
assert.Equal(t, original.OpType, result.OpType)
|
||||||
|
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||||
|
assert.Equal(t, original.DoRepository, result.DoRepository)
|
||||||
|
assert.Equal(t, original.Doid, result.Doid)
|
||||||
|
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||||
|
assert.Equal(t, original.OpActor, result.OpActor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTrip_Record(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
original := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: now,
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra-data"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to protobuf
|
||||||
|
pbRec, err := model.RecordToProtobuf(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, pbRec)
|
||||||
|
|
||||||
|
// Convert back to model
|
||||||
|
result, err := model.RecordFromProtobuf(pbRec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
// Verify round trip
|
||||||
|
assert.Equal(t, original.ID, result.ID)
|
||||||
|
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||||
|
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||||
|
assert.Equal(t, original.Timestamp.Unix(), result.Timestamp.Unix())
|
||||||
|
assert.Equal(t, original.Operator, result.Operator)
|
||||||
|
assert.Equal(t, original.Extra, result.Extra)
|
||||||
|
assert.Equal(t, original.RCType, result.RCType)
|
||||||
|
}
|
||||||
310
api/model/crypto_config.go
Normal file
310
api/model/crypto_config.go
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/crpt/go-crpt"
|
||||||
|
_ "github.com/crpt/go-crpt/ed25519" // Import Ed25519
|
||||||
|
_ "github.com/crpt/go-crpt/sm2" // Import SM2
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignatureAlgorithm 定义支持的签名算法类型.
|
||||||
|
type SignatureAlgorithm string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SM2 国密SM2算法
|
||||||
|
SM2Algorithm SignatureAlgorithm = "sm2"
|
||||||
|
// Ed25519 Ed25519算法
|
||||||
|
Ed25519Algorithm SignatureAlgorithm = "ed25519"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CryptoConfig 密码学配置
|
||||||
|
type CryptoConfig struct {
|
||||||
|
// SignatureAlgorithm 签名算法类型
|
||||||
|
// SM2 会自动使用 SM3 哈希,Ed25519 会使用 SHA512 哈希
|
||||||
|
SignatureAlgorithm SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// 默认配置:使用 SM2(内部自动使用 SM3)
|
||||||
|
defaultConfig = &CryptoConfig{
|
||||||
|
SignatureAlgorithm: SM2Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全局配置
|
||||||
|
globalConfig *CryptoConfig
|
||||||
|
globalConfigMutex sync.RWMutex
|
||||||
|
|
||||||
|
// ErrUnsupportedAlgorithm 不支持的算法错误
|
||||||
|
ErrUnsupportedAlgorithm = errors.New("unsupported signature algorithm")
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// 自动初始化全局配置为 SM2
|
||||||
|
globalConfig = defaultConfig
|
||||||
|
logger.GetGlobalLogger().Debug("Crypto config initialized with default SM2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGlobalCryptoConfig 设置全局密码学配置
|
||||||
|
func SetGlobalCryptoConfig(config *CryptoConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return errors.New("config cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证配置
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
globalConfigMutex.Lock()
|
||||||
|
defer globalConfigMutex.Unlock()
|
||||||
|
|
||||||
|
globalConfig = config
|
||||||
|
logger.GetGlobalLogger().Info("Global crypto config updated",
|
||||||
|
"signatureAlgorithm", config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalCryptoConfig 获取全局密码学配置
|
||||||
|
func GetGlobalCryptoConfig() *CryptoConfig {
|
||||||
|
globalConfigMutex.RLock()
|
||||||
|
defer globalConfigMutex.RUnlock()
|
||||||
|
|
||||||
|
if globalConfig == nil {
|
||||||
|
return defaultConfig
|
||||||
|
}
|
||||||
|
return globalConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 验证配置是否有效
|
||||||
|
func (c *CryptoConfig) Validate() error {
|
||||||
|
// 验证签名算法
|
||||||
|
switch c.SignatureAlgorithm {
|
||||||
|
case SM2Algorithm, Ed25519Algorithm:
|
||||||
|
// 支持的算法
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, c.SignatureAlgorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toKeyType 将 SignatureAlgorithm 转换为 crpt.KeyType
|
||||||
|
func (a SignatureAlgorithm) toKeyType() (crpt.KeyType, error) {
|
||||||
|
switch a {
|
||||||
|
case SM2Algorithm:
|
||||||
|
return crpt.SM2, nil
|
||||||
|
case Ed25519Algorithm:
|
||||||
|
return crpt.Ed25519, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyPair 通用密钥对,支持多种算法
|
||||||
|
type KeyPair struct {
|
||||||
|
Public crpt.PublicKey `json:"publicKey"`
|
||||||
|
Private crpt.PrivateKey `json:"privateKey"`
|
||||||
|
Algorithm SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateKeyPair 根据配置生成密钥对
|
||||||
|
func GenerateKeyPair(config *CryptoConfig) (*KeyPair, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = GetGlobalCryptoConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Generating key pair",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
keyType, err := config.SignatureAlgorithm.toKeyType()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, priv, err := crpt.GenerateKey(keyType, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to generate key pair",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to generate %s key pair: %w", config.SignatureAlgorithm, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Key pair generated successfully",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &KeyPair{
|
||||||
|
Public: pub,
|
||||||
|
Private: priv,
|
||||||
|
Algorithm: config.SignatureAlgorithm,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign 使用密钥对签名数据
|
||||||
|
func (kp *KeyPair) Sign(data []byte, rand io.Reader) ([]byte, error) {
|
||||||
|
if rand == nil {
|
||||||
|
rand = defaultRand()
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing data",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
signature, err := crpt.SignMessage(kp.Private, data, rand, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign data",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to sign with %s: %w", kp.Algorithm, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Data signed successfully",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 使用公钥验证签名
|
||||||
|
func (kp *KeyPair) Verify(data, signature []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying signature",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
ok, err := crpt.VerifyMessage(kp.Public, data, crpt.Signature(signature), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to verify signature",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf("failed to verify with %s: %w", kp.Algorithm, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debug("Signature verified successfully",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("Signature verification failed",
|
||||||
|
"algorithm", kp.Algorithm,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalPrivateKey 序列化私钥
|
||||||
|
func (kp *KeyPair) MarshalPrivateKey() ([]byte, error) {
|
||||||
|
if kp.Private == nil {
|
||||||
|
return nil, errors.New("private key is nil")
|
||||||
|
}
|
||||||
|
return kp.Private.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalPublicKey 序列化公钥
|
||||||
|
func (kp *KeyPair) MarshalPublicKey() ([]byte, error) {
|
||||||
|
if kp.Public == nil {
|
||||||
|
return nil, errors.New("public key is nil")
|
||||||
|
}
|
||||||
|
return kp.Public.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePrivateKey 解析私钥
|
||||||
|
func ParsePrivateKey(data []byte, algorithm SignatureAlgorithm) (crpt.PrivateKey, error) {
|
||||||
|
keyType, err := algorithm.toKeyType()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return crpt.PrivateKeyFromBytes(keyType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePublicKey 解析公钥
|
||||||
|
func ParsePublicKey(data []byte, algorithm SignatureAlgorithm) (crpt.PublicKey, error) {
|
||||||
|
keyType, err := algorithm.toKeyType()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return crpt.PublicKeyFromBytes(keyType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultRand 返回默认的随机数生成器
|
||||||
|
func defaultRand() io.Reader {
|
||||||
|
return rand.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignWithConfig 使用指定配置签名数据
|
||||||
|
func SignWithConfig(data, privateKeyDER []byte, config *CryptoConfig) ([]byte, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = GetGlobalCryptoConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing with config",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
privateKey, err := ParsePrivateKey(privateKeyDER, config.SignatureAlgorithm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := crpt.SignMessage(privateKey, data, rand.Reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to sign: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Signed with config successfully",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyWithConfig 使用指定配置验证签名
|
||||||
|
func VerifyWithConfig(data, publicKeyDER, signature []byte, config *CryptoConfig) (bool, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = GetGlobalCryptoConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying with config",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
publicKey, err := ParsePublicKey(publicKeyDER, config.SignatureAlgorithm)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to parse public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := crpt.VerifyMessage(publicKey, data, crpt.Signature(signature), nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to verify: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Verified with config",
|
||||||
|
"algorithm", config.SignatureAlgorithm,
|
||||||
|
"result", ok,
|
||||||
|
)
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
251
api/model/crypto_config_test.go
Normal file
251
api/model/crypto_config_test.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCryptoConfig_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *model.CryptoConfig
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid SM2 config",
|
||||||
|
config: &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.SM2Algorithm,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid Ed25519 config",
|
||||||
|
config: &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid signature algorithm",
|
||||||
|
config: &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: "rsa",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := tt.config.Validate()
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetGetGlobalCryptoConfig(t *testing.T) {
|
||||||
|
// 不使用 t.Parallel(),因为它修改全局状态
|
||||||
|
|
||||||
|
// 保存当前配置
|
||||||
|
original := model.GetGlobalCryptoConfig()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := model.SetGlobalCryptoConfig(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
retrieved := model.GetGlobalCryptoConfig()
|
||||||
|
assert.Equal(t, config.SignatureAlgorithm, retrieved.SignatureAlgorithm)
|
||||||
|
|
||||||
|
// 恢复原配置
|
||||||
|
_ = model.SetGlobalCryptoConfig(original)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateKeyPair_SM2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.SM2Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, keyPair)
|
||||||
|
assert.NotNil(t, keyPair.Public)
|
||||||
|
assert.NotNil(t, keyPair.Private)
|
||||||
|
assert.Equal(t, model.SM2Algorithm, keyPair.Algorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateKeyPair_Ed25519(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, keyPair)
|
||||||
|
assert.NotNil(t, keyPair.Public)
|
||||||
|
assert.NotNil(t, keyPair.Private)
|
||||||
|
assert.Equal(t, model.Ed25519Algorithm, keyPair.Algorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_SignAndVerify_SM2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.SM2Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data := []byte("test data for SM2 signing")
|
||||||
|
|
||||||
|
// Sign
|
||||||
|
signature, err := keyPair.Sign(data, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
ok, err := keyPair.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Verify with wrong data should fail
|
||||||
|
wrongData := []byte("wrong data")
|
||||||
|
ok, err = keyPair.Verify(wrongData, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_SignAndVerify_Ed25519(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data := []byte("test data for Ed25519 signing")
|
||||||
|
|
||||||
|
// Sign
|
||||||
|
signature, err := keyPair.Sign(data, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
ok, err := keyPair.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Verify with wrong data should fail
|
||||||
|
wrongData := []byte("wrong data")
|
||||||
|
ok, err = keyPair.Verify(wrongData, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_MarshalAndParse_SM2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: model.SM2Algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal private key
|
||||||
|
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, privateKeyDER)
|
||||||
|
|
||||||
|
// Marshal public key
|
||||||
|
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, publicKeyDER)
|
||||||
|
|
||||||
|
// Parse keys back
|
||||||
|
parsedPriv, err := model.ParsePrivateKey(privateKeyDER, model.SM2Algorithm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsedPriv)
|
||||||
|
|
||||||
|
parsedPub, err := model.ParsePublicKey(publicKeyDER, model.SM2Algorithm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsedPub)
|
||||||
|
|
||||||
|
// Test sign/verify with parsed keys
|
||||||
|
data := []byte("test data")
|
||||||
|
signature, err := model.SignWithConfig(data, privateKeyDER, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ok, err := model.VerifyWithConfig(data, publicKeyDER, signature, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignWithConfig_And_VerifyWithConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
algorithm model.SignatureAlgorithm
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SM2",
|
||||||
|
algorithm: model.SM2Algorithm,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Ed25519",
|
||||||
|
algorithm: model.Ed25519Algorithm,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := &model.CryptoConfig{
|
||||||
|
SignatureAlgorithm: tt.algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate key pair
|
||||||
|
keyPair, err := model.GenerateKeyPair(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal keys
|
||||||
|
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign
|
||||||
|
data := []byte("test data")
|
||||||
|
signature, err := model.SignWithConfig(data, privateKeyDER, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
ok, err := model.VerifyWithConfig(data, publicKeyDER, signature, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
501
api/model/envelope.go
Normal file
501
api/model/envelope.go
Normal file
@@ -0,0 +1,501 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Envelope 包装序列化后的数据,包含元信息和报文体。
|
||||||
|
// 用于 Trustlog 接口类型的序列化和反序列化。
|
||||||
|
type Envelope struct {
|
||||||
|
ProducerID string // 日志提供者ID
|
||||||
|
Signature []byte // 签名(根据客户端密钥与指定算法进行签名,二进制格式)
|
||||||
|
Body []byte // CBOR序列化的报文体
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnvelopeConfig 序列化配置。
|
||||||
|
type EnvelopeConfig struct {
|
||||||
|
Signer Signer // 签名器,用于签名和验签
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyConfig 验签配置。
|
||||||
|
type VerifyConfig struct {
|
||||||
|
Signer Signer // 签名器,用于验签
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvelopeConfig 创建Envelope配置。
|
||||||
|
func NewEnvelopeConfig(signer Signer) EnvelopeConfig {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new EnvelopeConfig",
|
||||||
|
"signerType", fmt.Sprintf("%T", signer),
|
||||||
|
)
|
||||||
|
return EnvelopeConfig{
|
||||||
|
Signer: signer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSM2EnvelopeConfig 创建使用SM2签名的Envelope配置。
|
||||||
|
// 便捷方法,用于快速创建SM2签名器配置。
|
||||||
|
func NewSM2EnvelopeConfig(privateKey, publicKey []byte) EnvelopeConfig {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new SM2 EnvelopeConfig",
|
||||||
|
"privateKeyLength", len(privateKey),
|
||||||
|
"publicKeyLength", len(publicKey),
|
||||||
|
)
|
||||||
|
return EnvelopeConfig{
|
||||||
|
Signer: NewSM2Signer(privateKey, publicKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVerifyConfig 创建验签配置。
|
||||||
|
func NewVerifyConfig(signer Signer) VerifyConfig {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new VerifyConfig",
|
||||||
|
"signerType", fmt.Sprintf("%T", signer),
|
||||||
|
)
|
||||||
|
return VerifyConfig{
|
||||||
|
Signer: signer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSM2VerifyConfig 创建使用SM2签名的验签配置。
|
||||||
|
// 便捷方法,用于快速创建SM2签名器验签配置。
|
||||||
|
// 注意:验签只需要公钥,但SM2Signer需要同时提供私钥和公钥(私钥可以为空)。
|
||||||
|
func NewSM2VerifyConfig(publicKey []byte) VerifyConfig {
|
||||||
|
return VerifyConfig{
|
||||||
|
Signer: NewSM2Signer(nil, publicKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== Envelope 序列化/反序列化 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// MarshalEnvelope 将 Envelope 序列化为 TLV 格式(Varint长度编码)。
|
||||||
|
// 格式:[字段1长度][字段1值:producerID][字段2长度][字段2值:签名][字段3长度][字段3值:CBOR报文体]。
|
||||||
|
func MarshalEnvelope(env *Envelope) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling envelope to TLV format")
|
||||||
|
if env == nil {
|
||||||
|
log.Error("Envelope is nil")
|
||||||
|
return nil, errors.New("envelope cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
writer := helpers.NewTLVWriter(buf)
|
||||||
|
|
||||||
|
log.Debug("Writing producerID to TLV",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
if err := writer.WriteStringField(env.ProducerID); err != nil {
|
||||||
|
log.Error("Failed to write producerID",
|
||||||
|
"error", err,
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to write producerID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Writing signature to TLV",
|
||||||
|
"signatureLength", len(env.Signature),
|
||||||
|
)
|
||||||
|
if err := writer.WriteField(env.Signature); err != nil {
|
||||||
|
log.Error("Failed to write signature",
|
||||||
|
"error", err,
|
||||||
|
"signatureLength", len(env.Signature),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to write signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Writing body to TLV",
|
||||||
|
"bodyLength", len(env.Body),
|
||||||
|
)
|
||||||
|
if err := writer.WriteField(env.Body); err != nil {
|
||||||
|
log.Error("Failed to write body",
|
||||||
|
"error", err,
|
||||||
|
"bodyLength", len(env.Body),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to write body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := buf.Bytes()
|
||||||
|
log.Debug("Envelope marshaled successfully",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
"totalLength", len(result),
|
||||||
|
)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalEnvelope 完整反序列化:读取所有字段。
|
||||||
|
// 解析完整的Envelope结构,包括所有元数据和Body。
|
||||||
|
// 为了向后兼容,如果遇到旧格式(包含原hash字段),会自动跳过该字段。
|
||||||
|
func UnmarshalEnvelope(data []byte) (*Envelope, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Unmarshaling envelope from TLV format",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
if len(data) == 0 {
|
||||||
|
log.Error("Data is empty")
|
||||||
|
return nil, errors.New("data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
reader := helpers.NewTLVReader(r)
|
||||||
|
|
||||||
|
log.Debug("Reading producerID from TLV")
|
||||||
|
producerID, err := reader.ReadStringField()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to read producerID",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to read producerID: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("ProducerID read successfully",
|
||||||
|
"producerID", producerID,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 读取第一个字段(可能是原hash或签名)
|
||||||
|
log.Debug("Reading field 1 from TLV")
|
||||||
|
field1, err := reader.ReadField()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to read field 1",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to read field 1: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("Field 1 read successfully",
|
||||||
|
"field1Length", len(field1),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 读取第二个字段(可能是签名或body)
|
||||||
|
log.Debug("Reading field 2 from TLV")
|
||||||
|
field2, err := reader.ReadField()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to read field 2",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to read field 2: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("Field 2 read successfully",
|
||||||
|
"field2Length", len(field2),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 尝试读取第三个字段来判断格式
|
||||||
|
log.Debug("Attempting to read field 3 to determine format")
|
||||||
|
field3, err := reader.ReadField()
|
||||||
|
if err == nil {
|
||||||
|
// 有第三个字段,说明是旧格式:producerID, originalHash, encryptedHash, body
|
||||||
|
// field1 = originalHash, field2 = encryptedHash/signature, field3 = body
|
||||||
|
log.Debug("Detected old format (with originalHash)",
|
||||||
|
"producerID", producerID,
|
||||||
|
"signatureLength", len(field2),
|
||||||
|
"bodyLength", len(field3),
|
||||||
|
)
|
||||||
|
return &Envelope{
|
||||||
|
ProducerID: producerID,
|
||||||
|
Signature: field2,
|
||||||
|
Body: field3,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 没有第三个字段,说明是新格式:producerID, signature, body
|
||||||
|
// field1 = signature, field2 = body
|
||||||
|
log.Debug("Detected new format (without originalHash)",
|
||||||
|
"producerID", producerID,
|
||||||
|
"signatureLength", len(field1),
|
||||||
|
"bodyLength", len(field2),
|
||||||
|
)
|
||||||
|
return &Envelope{
|
||||||
|
ProducerID: producerID,
|
||||||
|
Signature: field1,
|
||||||
|
Body: field2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 部分反序列化(无需反序列化全部报文) =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// UnmarshalEnvelopeProducerID 部分反序列化:只读取字段1(producerID)。
|
||||||
|
// 用于快速获取producerID而不解析整个Envelope。
|
||||||
|
func UnmarshalEnvelopeProducerID(data []byte) (string, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "", errors.New("data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
reader := helpers.NewTLVReader(r)
|
||||||
|
|
||||||
|
producerID, err := reader.ReadStringField()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read producerID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return producerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalEnvelopeSignature 部分反序列化:读取字段1、2(producerID, 签名)。
|
||||||
|
// 用于获取签名信息而不解析整个Body。
|
||||||
|
// 为了向后兼容,如果遇到旧格式(包含原hash字段),会自动跳过该字段。
|
||||||
|
func UnmarshalEnvelopeSignature(data []byte) (string, []byte, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "", nil, errors.New("data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
reader := helpers.NewTLVReader(r)
|
||||||
|
|
||||||
|
producerID, err := reader.ReadStringField()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to read producerID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取第一个字段(可能是原hash或签名)
|
||||||
|
field1, err := reader.ReadField()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to read field 1: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取第二个字段(可能是签名或body)
|
||||||
|
field2, err := reader.ReadField()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to read field 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试读取第三个字段来判断格式
|
||||||
|
_, err = reader.ReadField()
|
||||||
|
if err == nil {
|
||||||
|
// 有第三个字段,说明是旧格式:producerID, originalHash, encryptedHash/signature, body
|
||||||
|
// field1 = originalHash, field2 = signature
|
||||||
|
return producerID, field2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 没有第三个字段,说明是新格式:producerID, signature, body
|
||||||
|
// field1 = signature
|
||||||
|
return producerID, field1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== Trustlog 序列化/反序列化 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// MarshalTrustlog 序列化 Trustlog 为 Envelope 格式。
|
||||||
|
// Trustlog 实现了 Trustlog 接口,自动提取 producerID 并使用 Canonical CBOR 编码。
|
||||||
|
func MarshalTrustlog(t Trustlog, config EnvelopeConfig) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling Trustlog to Envelope format",
|
||||||
|
"trustlogType", fmt.Sprintf("%T", t),
|
||||||
|
)
|
||||||
|
if t == nil {
|
||||||
|
log.Error("Trustlog is nil")
|
||||||
|
return nil, errors.New("trustlog cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
producerID := t.GetProducerID()
|
||||||
|
if producerID == "" {
|
||||||
|
log.Error("ProducerID is empty")
|
||||||
|
return nil, errors.New("producerID cannot be empty")
|
||||||
|
}
|
||||||
|
log.Debug("ProducerID extracted",
|
||||||
|
"producerID", producerID,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 1. 序列化CBOR报文体(使用 Trustlog 的 MarshalBinary,确保使用 Canonical CBOR)
|
||||||
|
log.Debug("Marshaling trustlog to CBOR binary")
|
||||||
|
bodyCBOR, err := t.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal trustlog to CBOR",
|
||||||
|
"error", err,
|
||||||
|
"producerID", producerID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to marshal trustlog: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("Trustlog marshaled to CBOR successfully",
|
||||||
|
"producerID", producerID,
|
||||||
|
"bodyLength", len(bodyCBOR),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 2. 计算签名
|
||||||
|
if config.Signer == nil {
|
||||||
|
log.Error("Signer is nil")
|
||||||
|
return nil, errors.New("signer is required")
|
||||||
|
}
|
||||||
|
log.Debug("Signing trustlog body",
|
||||||
|
"producerID", producerID,
|
||||||
|
"bodyLength", len(bodyCBOR),
|
||||||
|
)
|
||||||
|
signature, err := config.Signer.Sign(bodyCBOR)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign trustlog body",
|
||||||
|
"error", err,
|
||||||
|
"producerID", producerID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to sign data: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("Trustlog body signed successfully",
|
||||||
|
"producerID", producerID,
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 3. 构建Envelope
|
||||||
|
env := &Envelope{
|
||||||
|
ProducerID: producerID,
|
||||||
|
Signature: signature,
|
||||||
|
Body: bodyCBOR,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 序列化为TLV格式
|
||||||
|
log.Debug("Marshaling envelope to TLV format",
|
||||||
|
"producerID", producerID,
|
||||||
|
)
|
||||||
|
return MarshalEnvelope(env)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalTrustlog 反序列化 Envelope 为 Trustlog。
|
||||||
|
// 解析Envelope数据并恢复 Trustlog 结构。
|
||||||
|
func UnmarshalTrustlog(data []byte, t Trustlog) error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Unmarshaling Envelope to Trustlog",
|
||||||
|
"trustlogType", fmt.Sprintf("%T", t),
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
if t == nil {
|
||||||
|
log.Error("Trustlog is nil")
|
||||||
|
return errors.New("trustlog cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
env, err := UnmarshalEnvelope(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to unmarshal envelope",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug("Envelope unmarshaled successfully",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
"bodyLength", len(env.Body),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 使用 Trustlog 的 UnmarshalBinary 反序列化
|
||||||
|
log.Debug("Unmarshaling trustlog body from CBOR",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
if errUnmarshal := t.UnmarshalBinary(env.Body); errUnmarshal != nil {
|
||||||
|
log.Error("Failed to unmarshal trustlog body",
|
||||||
|
"error", errUnmarshal,
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("failed to unmarshal trustlog body: %w", errUnmarshal)
|
||||||
|
}
|
||||||
|
log.Debug("Trustlog unmarshaled successfully",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== Operation 序列化/反序列化 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// MarshalOperation 序列化 Operation 为 Envelope 格式。
|
||||||
|
func MarshalOperation(op *Operation, config EnvelopeConfig) ([]byte, error) {
|
||||||
|
return MarshalTrustlog(op, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalOperation 反序列化 Envelope 为 Operation。
|
||||||
|
func UnmarshalOperation(data []byte) (*Operation, error) {
|
||||||
|
var op Operation
|
||||||
|
if err := UnmarshalTrustlog(data, &op); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &op, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== Record 序列化/反序列化 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// MarshalRecord 序列化 Record 为 Envelope 格式。
|
||||||
|
func MarshalRecord(record *Record, config EnvelopeConfig) ([]byte, error) {
|
||||||
|
return MarshalTrustlog(record, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalRecord 反序列化 Envelope 为 Record。
|
||||||
|
func UnmarshalRecord(data []byte) (*Record, error) {
|
||||||
|
var record Record
|
||||||
|
if err := UnmarshalTrustlog(data, &record); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 验证 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// VerifyEnvelope 验证Envelope的完整性(使用EnvelopeConfig)。
|
||||||
|
// 验证签名是否匹配,确保数据未被篡改。
|
||||||
|
// 如果验证成功,返回解析后的Envelope结构体指针;如果验证失败,返回错误。
|
||||||
|
func VerifyEnvelope(data []byte, config EnvelopeConfig) (*Envelope, error) {
|
||||||
|
if config.Signer == nil {
|
||||||
|
return nil, errors.New("signer is required for verification")
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyConfig := VerifyConfig(config)
|
||||||
|
return VerifyEnvelopeWithConfig(data, verifyConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyEnvelopeWithConfig 验证Envelope的完整性(使用VerifyConfig)。
|
||||||
|
// 验证签名是否匹配,确保数据未被篡改。
|
||||||
|
// 如果验证成功,返回解析后的Envelope结构体指针;如果验证失败,返回错误。
|
||||||
|
func VerifyEnvelopeWithConfig(data []byte, config VerifyConfig) (*Envelope, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying envelope",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
if config.Signer == nil {
|
||||||
|
log.Error("Signer is nil")
|
||||||
|
return nil, errors.New("signer is required for verification")
|
||||||
|
}
|
||||||
|
|
||||||
|
env, err := UnmarshalEnvelope(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to unmarshal envelope",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal envelope: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("Envelope unmarshaled for verification",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
"bodyLength", len(env.Body),
|
||||||
|
"signatureLength", len(env.Signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 验证签名
|
||||||
|
log.Debug("Verifying signature",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
valid, err := config.Signer.Verify(env.Body, env.Signature)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to verify signature",
|
||||||
|
"error", err,
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to verify signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
log.Warn("Signature verification failed",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return nil, errors.New("signature verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Envelope verified successfully",
|
||||||
|
"producerID", env.ProducerID,
|
||||||
|
)
|
||||||
|
return env, nil
|
||||||
|
}
|
||||||
215
api/model/envelope_debug_test.go
Normal file
215
api/model/envelope_debug_test.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSignVerifyDataConsistency 详细测试加签和验签的数据一致性.
|
||||||
|
func TestSignVerifyDataConsistency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 序列化为DER格式
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名器
|
||||||
|
signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||||
|
|
||||||
|
// 测试数据1
|
||||||
|
testData1 := []byte("test data for signing")
|
||||||
|
|
||||||
|
// 测试数据2(不同数据)
|
||||||
|
testData2 := []byte("different test data")
|
||||||
|
|
||||||
|
// 1. 对testData1签名
|
||||||
|
signature1, err := signer.Sign(testData1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signature1)
|
||||||
|
|
||||||
|
// 2. 用testData1验证signature1 - 应该成功
|
||||||
|
valid, err := signer.Verify(testData1, signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||||
|
|
||||||
|
// 3. 用testData2验证signature1 - 应该失败
|
||||||
|
valid, err = signer.Verify(testData2, signature1)
|
||||||
|
require.Error(t, err, "使用不同数据验证应该失败")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
// 4. 对testData2签名
|
||||||
|
signature2, err := signer.Sign(testData2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signature2)
|
||||||
|
|
||||||
|
// 5. 用testData2验证signature2 - 应该成功
|
||||||
|
valid, err = signer.Verify(testData2, signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||||
|
|
||||||
|
// 6. 用testData1验证signature2 - 应该失败
|
||||||
|
valid, err = signer.Verify(testData1, signature2)
|
||||||
|
require.Error(t, err, "使用不同数据验证应该失败")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
t.Logf("测试完成:签名和验证逻辑正常")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnvelopeBodyTampering 测试修改envelope body后验签应该失败.
|
||||||
|
func TestEnvelopeBodyTampering(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 序列化为DER格式
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名配置
|
||||||
|
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||||
|
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||||
|
|
||||||
|
// 创建测试Operation
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-test-002",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/456",
|
||||||
|
ProducerID: "producer-2",
|
||||||
|
OpActor: "actor-2",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 1. 加签:序列化为Envelope
|
||||||
|
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, envelopeData)
|
||||||
|
|
||||||
|
// 2. 验签:验证原始Envelope - 应该成功
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(envelopeData, verifyConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, verifiedEnv)
|
||||||
|
|
||||||
|
// 3. 反序列化获取原始body
|
||||||
|
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
originalBody := originalEnv.Body
|
||||||
|
originalSignature := originalEnv.Signature
|
||||||
|
|
||||||
|
t.Logf("原始body长度: %d", len(originalBody))
|
||||||
|
t.Logf("原始签名长度: %d", len(originalSignature))
|
||||||
|
|
||||||
|
// 4. 创建修改后的body(完全不同的数据)
|
||||||
|
modifiedBody := []byte("completely different body content")
|
||||||
|
require.NotEqual(t, originalBody, modifiedBody, "修改后的body应该不同")
|
||||||
|
|
||||||
|
// 5. 创建修改后的envelope(使用原始签名但修改body)
|
||||||
|
modifiedEnv := &model.Envelope{
|
||||||
|
ProducerID: originalEnv.ProducerID,
|
||||||
|
Signature: originalSignature, // 使用原始签名
|
||||||
|
Body: modifiedBody, // 使用修改后的body
|
||||||
|
}
|
||||||
|
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 6. 验签修改后的envelope - 应该失败
|
||||||
|
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||||
|
require.Error(t, err, "修改body后验签应该失败")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
|
||||||
|
t.Logf("测试完成:修改body后验签正确失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnvelopeSignatureTampering 测试修改envelope signature后验签应该失败.
|
||||||
|
func TestEnvelopeSignatureTampering(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 序列化为DER格式
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名配置
|
||||||
|
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||||
|
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||||
|
|
||||||
|
// 创建测试Operation
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-test-003",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/789",
|
||||||
|
ProducerID: "producer-3",
|
||||||
|
OpActor: "actor-3",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 1. 加签:序列化为Envelope
|
||||||
|
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. 反序列化获取原始signature
|
||||||
|
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
originalSignature := originalEnv.Signature
|
||||||
|
|
||||||
|
// 3. 创建修改后的signature(完全不同的数据)
|
||||||
|
modifiedSignature := make([]byte, len(originalSignature))
|
||||||
|
copy(modifiedSignature, originalSignature)
|
||||||
|
// 修改最后一个字节
|
||||||
|
if len(modifiedSignature) > 0 {
|
||||||
|
modifiedSignature[len(modifiedSignature)-1] ^= 0xFF
|
||||||
|
}
|
||||||
|
require.NotEqual(t, originalSignature, modifiedSignature, "修改后的signature应该不同")
|
||||||
|
|
||||||
|
// 4. 创建修改后的envelope(使用原始body但修改signature)
|
||||||
|
modifiedEnv := &model.Envelope{
|
||||||
|
ProducerID: originalEnv.ProducerID,
|
||||||
|
Signature: modifiedSignature, // 使用修改后的signature
|
||||||
|
Body: originalEnv.Body, // 使用原始body
|
||||||
|
}
|
||||||
|
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 5. 验签修改后的envelope - 应该失败
|
||||||
|
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||||
|
require.Error(t, err, "修改signature后验签应该失败")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
|
||||||
|
t.Logf("测试完成:修改signature后验签正确失败")
|
||||||
|
}
|
||||||
126
api/model/envelope_sign_verify_test.go
Normal file
126
api/model/envelope_sign_verify_test.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSignVerifyConsistency 测试加签和验签的一致性
|
||||||
|
// 验证加签时使用的数据和验签时使用的数据是否一致.
|
||||||
|
func TestSignVerifyConsistency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 序列化为DER格式
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名配置
|
||||||
|
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||||
|
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||||
|
|
||||||
|
// 创建测试Operation
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-test-001",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 1. 加签:序列化为Envelope
|
||||||
|
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, envelopeData)
|
||||||
|
|
||||||
|
// 2. 验签:验证Envelope
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(envelopeData, verifyConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, verifiedEnv)
|
||||||
|
|
||||||
|
// 3. 验证:加签时使用的body和验签时使用的body应该一致
|
||||||
|
// 手动反序列化envelope以获取body
|
||||||
|
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证body一致
|
||||||
|
assert.Equal(t, originalEnv.Body, verifiedEnv.Body, "加签和验签时使用的body应该完全一致")
|
||||||
|
assert.Equal(t, originalEnv.ProducerID, verifiedEnv.ProducerID)
|
||||||
|
assert.Equal(t, originalEnv.Signature, verifiedEnv.Signature)
|
||||||
|
|
||||||
|
// 4. 验证:如果修改body,验签应该失败
|
||||||
|
// 创建完全不同的body内容
|
||||||
|
modifiedBody := []byte("completely different body content")
|
||||||
|
require.NotEqual(t, originalEnv.Body, modifiedBody, "修改后的body应该不同")
|
||||||
|
|
||||||
|
modifiedEnv := &model.Envelope{
|
||||||
|
ProducerID: originalEnv.ProducerID,
|
||||||
|
Signature: originalEnv.Signature, // 使用旧的签名
|
||||||
|
Body: modifiedBody, // 使用修改后的body
|
||||||
|
}
|
||||||
|
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验签应该失败,因为body被修改了但签名还是旧的
|
||||||
|
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||||
|
require.Error(t, err, "修改body后验签应该失败")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSignVerifyDirectData 直接测试对相同数据的签名和验证.
|
||||||
|
func TestSignVerifyDirectData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 序列化为DER格式
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 创建签名器
|
||||||
|
signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||||
|
|
||||||
|
// 测试数据
|
||||||
|
testData := []byte("test data for signing")
|
||||||
|
|
||||||
|
// 1. 签名
|
||||||
|
signature, err := signer.Sign(testData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signature)
|
||||||
|
|
||||||
|
// 2. 验证(使用相同的数据)
|
||||||
|
valid, err := signer.Verify(testData, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||||
|
|
||||||
|
// 3. 验证(使用不同的数据)
|
||||||
|
modifiedData := []byte("modified test data")
|
||||||
|
valid, err = signer.Verify(modifiedData, signature)
|
||||||
|
// VerifySignature在验证失败时会返回错误,这是预期的
|
||||||
|
require.Error(t, err, "使用不同数据验证应该失败并返回错误")
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
assert.False(t, valid)
|
||||||
|
}
|
||||||
423
api/model/envelope_test.go
Normal file
423
api/model/envelope_test.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewEnvelopeConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
config := model.NewEnvelopeConfig(signer)
|
||||||
|
assert.NotNil(t, config.Signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSM2EnvelopeConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
privateKey := []byte("test-private-key")
|
||||||
|
publicKey := []byte("test-public-key")
|
||||||
|
|
||||||
|
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||||
|
assert.NotNil(t, config.Signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewVerifyConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
config := model.NewVerifyConfig(signer)
|
||||||
|
assert.NotNil(t, config.Signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSM2VerifyConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
publicKey := []byte("test-public-key")
|
||||||
|
|
||||||
|
config := model.NewSM2VerifyConfig(publicKey)
|
||||||
|
assert.NotNil(t, config.Signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalEnvelope_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.MarshalEnvelope(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "envelope cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalEnvelope_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
assert.NotEmpty(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalEnvelope_EmptyFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "",
|
||||||
|
Signature: []byte{},
|
||||||
|
Body: []byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelope_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.UnmarshalEnvelope(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelope_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.UnmarshalEnvelope([]byte{})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalEnvelope_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := model.MarshalEnvelope(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
result, err := model.UnmarshalEnvelope(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||||
|
assert.Equal(t, original.Signature, result.Signature)
|
||||||
|
assert.Equal(t, original.Body, result.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelopeProducerID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producerID, err := model.UnmarshalEnvelopeProducerID(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "producer-1", producerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelopeSignature(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
producerID, signature, err := model.UnmarshalEnvelopeSignature(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "producer-1", producerID)
|
||||||
|
assert.Equal(t, []byte("signature"), signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelopeSignature_EmptyData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, _, err := model.UnmarshalEnvelopeSignature(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelopeSignature_InvalidData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, _, err := model.UnmarshalEnvelopeSignature([]byte{0xff, 0xff})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalEnvelopeProducerID_EmptyData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.UnmarshalEnvelopeProducerID(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalTrustlog_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.MarshalTrustlog(nil, model.EnvelopeConfig{})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "trustlog cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalTrustlog_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalTrustlog(op, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalTrustlog_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{}
|
||||||
|
err := model.UnmarshalTrustlog(nil, op)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalOperation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalOperation(op, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalOperation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalOperation(op, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := model.UnmarshalOperation(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, op.OpID, result.OpID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rec.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalRecord(rec, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rec.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalRecord(rec, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := model.UnmarshalRecord(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, rec.ID, result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelope_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
env, err := model.VerifyEnvelope(nil, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, env)
|
||||||
|
assert.Contains(t, err.Error(), "data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelope_Basic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
verifiedEnv, err := model.VerifyEnvelope(data, config)
|
||||||
|
// NopSigner verifies by comparing body with signature
|
||||||
|
// Since signature != body, verification should fail
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, verifiedEnv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelopeWithConfig_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config := model.NewVerifyConfig(model.NewNopSigner())
|
||||||
|
env, err := model.VerifyEnvelopeWithConfig(nil, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, env)
|
||||||
|
// Error message may vary, just check that it's an error
|
||||||
|
assert.NotEmpty(t, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelopeWithConfig_NilSigner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.VerifyConfig{Signer: nil}
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(data, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, verifiedEnv)
|
||||||
|
assert.Contains(t, err.Error(), "signer is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelopeWithConfig_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create envelope with matching body and signature (NopSigner requirement)
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("body"), // Same as body for NopSigner
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.NewVerifyConfig(model.NewNopSigner())
|
||||||
|
verifiedEnv, err := model.VerifyEnvelopeWithConfig(data, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, verifiedEnv)
|
||||||
|
assert.Equal(t, env.ProducerID, verifiedEnv.ProducerID)
|
||||||
|
assert.Equal(t, env.Signature, verifiedEnv.Signature)
|
||||||
|
assert.Equal(t, env.Body, verifiedEnv.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyEnvelope_NilSigner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
env := &model.Envelope{
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Signature: []byte("signature"),
|
||||||
|
Body: []byte("body"),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := model.MarshalEnvelope(env)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.EnvelopeConfig{Signer: nil}
|
||||||
|
verifiedEnv, err := model.VerifyEnvelope(data, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, verifiedEnv)
|
||||||
|
assert.Contains(t, err.Error(), "signer is required")
|
||||||
|
}
|
||||||
267
api/model/hash.go
Normal file
267
api/model/hash.go
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha1"
|
||||||
|
stdsha256 "crypto/sha256"
|
||||||
|
stdsha512 "crypto/sha512"
|
||||||
|
"encoding/hex"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
miniosha256 "github.com/minio/sha256-simd"
|
||||||
|
"github.com/zeebo/blake3"
|
||||||
|
"golang.org/x/crypto/blake2b"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"golang.org/x/crypto/md4" //nolint:staticcheck // 保留弱加密算法以支持遗留系统兼容性
|
||||||
|
"golang.org/x/crypto/ripemd160" //nolint:staticcheck // 保留弱加密算法以支持遗留系统兼容性
|
||||||
|
"golang.org/x/crypto/sha3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HashType 定义支持的哈希算法类型.
|
||||||
|
type HashType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MD5 HashType = "md5"
|
||||||
|
SHA1 HashType = "sha1"
|
||||||
|
SHA224 HashType = "sha224"
|
||||||
|
SHA256 HashType = "sha256"
|
||||||
|
SHA384 HashType = "sha384"
|
||||||
|
SHA512 HashType = "sha512"
|
||||||
|
Sha512224 HashType = "sha512_224"
|
||||||
|
Sha512256 HashType = "sha512_256"
|
||||||
|
|
||||||
|
Sha256Simd HashType = "sha256-simd"
|
||||||
|
BLAKE3 HashType = "blake3"
|
||||||
|
BLAKE2B HashType = "blake2b"
|
||||||
|
BLAKE2S HashType = "blake2s"
|
||||||
|
MD4 HashType = "md4"
|
||||||
|
RIPEMD160 HashType = "ripemd160"
|
||||||
|
Sha3224 HashType = "sha3-224"
|
||||||
|
Sha3256 HashType = "sha3-256"
|
||||||
|
Sha3384 HashType = "sha3-384"
|
||||||
|
Sha3512 HashType = "sha3-512"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 使用 map 来存储支持的算法,提高查找效率.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // 全局缓存用于算法查找和实例复用.
|
||||||
|
var (
|
||||||
|
supportedAlgorithms []string
|
||||||
|
supportedAlgorithmsMap map[string]bool
|
||||||
|
supportedAlgorithmsOnce sync.Once
|
||||||
|
|
||||||
|
// 享元模式:存储已创建的 HashTool 实例.
|
||||||
|
toolPool = make(map[HashType]*HashTool)
|
||||||
|
poolMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// HashTool 哈希工具类.
|
||||||
|
type HashTool struct {
|
||||||
|
hashType HashType
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHashTool 获取指定类型的 HashTool.
|
||||||
|
func GetHashTool(hashType HashType) *HashTool {
|
||||||
|
poolMutex.RLock()
|
||||||
|
if tool, exists := toolPool[hashType]; exists {
|
||||||
|
poolMutex.RUnlock()
|
||||||
|
return tool
|
||||||
|
}
|
||||||
|
poolMutex.RUnlock()
|
||||||
|
|
||||||
|
poolMutex.Lock()
|
||||||
|
defer poolMutex.Unlock()
|
||||||
|
|
||||||
|
if tool, exists := toolPool[hashType]; exists {
|
||||||
|
return tool
|
||||||
|
}
|
||||||
|
|
||||||
|
tool := &HashTool{hashType: hashType}
|
||||||
|
toolPool[hashType] = tool
|
||||||
|
return tool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHashTool 创建新的哈希工具实例.
|
||||||
|
func NewHashTool(hashType HashType) *HashTool {
|
||||||
|
return &HashTool{hashType: hashType}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHasher 根据哈希类型获取对应的哈希器.
|
||||||
|
func (h *HashTool) getHasher() hash.Hash {
|
||||||
|
switch h.hashType {
|
||||||
|
case MD5:
|
||||||
|
return md5.New()
|
||||||
|
case SHA1:
|
||||||
|
return sha1.New()
|
||||||
|
case SHA224:
|
||||||
|
return stdsha256.New224()
|
||||||
|
case SHA256:
|
||||||
|
return stdsha256.New()
|
||||||
|
case SHA384:
|
||||||
|
return stdsha512.New384()
|
||||||
|
case SHA512:
|
||||||
|
return stdsha512.New()
|
||||||
|
case Sha512224:
|
||||||
|
return stdsha512.New512_224()
|
||||||
|
case Sha512256:
|
||||||
|
return stdsha512.New512_256()
|
||||||
|
|
||||||
|
// 第三方算法
|
||||||
|
case Sha256Simd:
|
||||||
|
return miniosha256.New()
|
||||||
|
case BLAKE3:
|
||||||
|
return blake3.New()
|
||||||
|
case BLAKE2B:
|
||||||
|
hasher, _ := blake2b.New512(nil)
|
||||||
|
return hasher
|
||||||
|
case BLAKE2S:
|
||||||
|
hasher, _ := blake2s.New256(nil)
|
||||||
|
return hasher
|
||||||
|
case MD4:
|
||||||
|
return md4.New()
|
||||||
|
case RIPEMD160:
|
||||||
|
return ripemd160.New()
|
||||||
|
case Sha3224:
|
||||||
|
return sha3.New224()
|
||||||
|
case Sha3256:
|
||||||
|
return sha3.New256()
|
||||||
|
case Sha3384:
|
||||||
|
return sha3.New384()
|
||||||
|
case Sha3512:
|
||||||
|
return sha3.New512()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return stdsha256.New() // 默认使用 SHA256
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashData 通用的哈希计算函数.
|
||||||
|
func (h *HashTool) hashData(processFunc func(hasher hash.Hash) error) (string, error) {
|
||||||
|
hasher := h.getHasher()
|
||||||
|
if err := processFunc(hasher); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashString 对字符串进行哈希计算.
|
||||||
|
func (h *HashTool) HashString(data string) (string, error) {
|
||||||
|
return h.hashData(func(hasher hash.Hash) error {
|
||||||
|
_, err := hasher.Write([]byte(data))
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashBytes 对字节数组进行哈希计算.
|
||||||
|
func (h *HashTool) HashBytes(data []byte) (string, error) {
|
||||||
|
return h.hashData(func(hasher hash.Hash) error {
|
||||||
|
_, err := hasher.Write(data)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashBytesRaw 对字节数组进行哈希计算,返回原始字节数组(非hex字符串).
|
||||||
|
func (h *HashTool) HashBytesRaw(data []byte) ([]byte, error) {
|
||||||
|
hasher := h.getHasher()
|
||||||
|
if _, err := hasher.Write(data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hasher.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashFile 对文件进行哈希计算.
|
||||||
|
func (h *HashTool) HashFile(_ context.Context, filePath string) (string, error) {
|
||||||
|
file, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
return h.hashData(func(hasher hash.Hash) error {
|
||||||
|
_, copyErr := io.Copy(hasher, file)
|
||||||
|
return copyErr
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashStream 对流数据进行哈希计算.
|
||||||
|
func (h *HashTool) HashStream(reader io.Reader) (string, error) {
|
||||||
|
return h.hashData(func(hasher hash.Hash) error {
|
||||||
|
_, err := io.Copy(hasher, reader)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// initSupportedAlgorithms 初始化支持的算法数据.
|
||||||
|
func initSupportedAlgorithms() {
|
||||||
|
algorithms := []HashType{
|
||||||
|
MD5, SHA1, SHA224, SHA256, SHA384, SHA512,
|
||||||
|
Sha512224, Sha512256, Sha256Simd, BLAKE3,
|
||||||
|
BLAKE2B, BLAKE2S, MD4, RIPEMD160,
|
||||||
|
Sha3224, Sha3256, Sha3384, Sha3512,
|
||||||
|
}
|
||||||
|
|
||||||
|
supportedAlgorithms = make([]string, len(algorithms))
|
||||||
|
supportedAlgorithmsMap = make(map[string]bool, len(algorithms))
|
||||||
|
|
||||||
|
for i, alg := range algorithms {
|
||||||
|
algStr := string(alg)
|
||||||
|
supportedAlgorithms[i] = algStr
|
||||||
|
supportedAlgorithmsMap[strings.ToLower(algStr)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSupportedAlgorithms 获取支持的哈希算法列表.
|
||||||
|
func GetSupportedAlgorithms() []string {
|
||||||
|
supportedAlgorithmsOnce.Do(initSupportedAlgorithms)
|
||||||
|
return supportedAlgorithms
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAlgorithmSupported 检查算法是否支持 - 使用 map 提高性能.
|
||||||
|
func IsAlgorithmSupported(algorithm string) bool {
|
||||||
|
supportedAlgorithmsOnce.Do(initSupportedAlgorithms)
|
||||||
|
return supportedAlgorithmsMap[strings.ToLower(algorithm)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareHash 比较哈希值.
|
||||||
|
func (h *HashTool) CompareHash(data, expectedHash string) (bool, error) {
|
||||||
|
actualHash, err := h.HashString(data)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return strings.EqualFold(actualHash, expectedHash), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareFileHash 比较文件哈希值.
|
||||||
|
func (h *HashTool) CompareFileHash(ctx context.Context, filePath, expectedHash string) (bool, error) {
|
||||||
|
actualHash, err := h.HashFile(ctx, filePath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return strings.EqualFold(actualHash, expectedHash), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHashType 获取当前工具使用的哈希类型.
|
||||||
|
func (h *HashTool) GetHashType() HashType {
|
||||||
|
return h.hashType
|
||||||
|
}
|
||||||
|
|
||||||
|
type HashData interface {
|
||||||
|
Key() string
|
||||||
|
Hash() string
|
||||||
|
Type() HashType
|
||||||
|
}
|
||||||
|
|
||||||
|
type Hashable interface {
|
||||||
|
DoHash(ctx context.Context) (HashData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type HashList []HashData
|
||||||
|
|
||||||
|
func (h HashList) GetHashType() HashType {
|
||||||
|
return h[0].Type()
|
||||||
|
}
|
||||||
545
api/model/hash_test.go
Normal file
545
api/model/hash_test.go
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetHashTool(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hashType model.HashType
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SHA256",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA256Simd",
|
||||||
|
hashType: model.Sha256Simd,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MD5",
|
||||||
|
hashType: model.MD5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA1",
|
||||||
|
hashType: model.SHA1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tool := model.GetHashTool(tt.hashType)
|
||||||
|
assert.NotNil(t, tool)
|
||||||
|
// Verify it works
|
||||||
|
_, err := tool.HashString("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Verify hash type
|
||||||
|
assert.Equal(t, tt.hashType, tool.GetHashType())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHashTool(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
assert.NotNil(t, tool)
|
||||||
|
// Verify it works
|
||||||
|
_, err := tool.HashString("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashString(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hashType model.HashType
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SHA256",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA256Simd",
|
||||||
|
hashType: model.Sha256Simd,
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MD5",
|
||||||
|
hashType: model.MD5,
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA1",
|
||||||
|
hashType: model.SHA1,
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA512",
|
||||||
|
hashType: model.SHA512,
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
input: "",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tool := model.NewHashTool(tt.hashType)
|
||||||
|
result, err := tool.HashString(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashBytes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hashType model.HashType
|
||||||
|
input []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SHA256",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
input: []byte("test"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA256Simd",
|
||||||
|
hashType: model.Sha256Simd,
|
||||||
|
input: []byte("test"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty bytes",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
input: []byte{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large input",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
input: make([]byte, 1000),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tool := model.NewHashTool(tt.hashType)
|
||||||
|
result, err := tool.HashBytes(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_Deterministic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
input := "test string"
|
||||||
|
|
||||||
|
result1, err1 := tool.HashString(input)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
result2, err2 := tool.HashString(input)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
|
||||||
|
// Same input should produce same hash
|
||||||
|
assert.Equal(t, result1, result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_DifferentInputs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
|
||||||
|
result1, err1 := tool.HashString("input1")
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
result2, err2 := tool.HashString("input2")
|
||||||
|
require.NoError(t, err2)
|
||||||
|
|
||||||
|
// Different inputs should produce different hashes
|
||||||
|
assert.NotEqual(t, result1, result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_StringVsBytes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
input := "test"
|
||||||
|
|
||||||
|
stringHash, err1 := tool.HashString(input)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
bytesHash, err2 := tool.HashBytes([]byte(input))
|
||||||
|
require.NoError(t, err2)
|
||||||
|
|
||||||
|
// Same data in different formats should produce same hash
|
||||||
|
assert.Equal(t, stringHash, bytesHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_MultipleTypes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
input := "test"
|
||||||
|
hashTypes := []model.HashType{
|
||||||
|
model.MD5,
|
||||||
|
model.SHA1,
|
||||||
|
model.SHA256,
|
||||||
|
model.SHA512,
|
||||||
|
model.Sha256Simd,
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make(map[model.HashType]string)
|
||||||
|
for _, hashType := range hashTypes {
|
||||||
|
tool := model.NewHashTool(hashType)
|
||||||
|
result, err := tool.HashString(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
results[hashType] = result
|
||||||
|
}
|
||||||
|
|
||||||
|
// All should produce different hashes (except possibly some edge cases)
|
||||||
|
// At minimum, verify they all produced valid hashes
|
||||||
|
for hashType, result := range results {
|
||||||
|
assert.NotEmpty(t, result, "HashType: %v", hashType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_GetHashTool_Caching(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
hashType := model.SHA256
|
||||||
|
tool1 := model.GetHashTool(hashType)
|
||||||
|
tool2 := model.GetHashTool(hashType)
|
||||||
|
|
||||||
|
// Should return the same instance (cached)
|
||||||
|
assert.Equal(t, tool1, tool2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create a temporary file
|
||||||
|
tmpFile := t.TempDir() + "/test.txt"
|
||||||
|
err := os.WriteFile(tmpFile, []byte("test content"), 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := tool.HashFile(ctx, tmpFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashFile_NotExists(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := tool.HashFile(ctx, "/nonexistent/file")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashStream(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
reader := bytes.NewReader([]byte("test content"))
|
||||||
|
|
||||||
|
result, err := tool.HashStream(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_HashStream_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
reader := bytes.NewReader([]byte{})
|
||||||
|
|
||||||
|
result, err := tool.HashStream(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, result) // Even empty input produces a hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSupportedAlgorithms(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
algorithms := model.GetSupportedAlgorithms()
|
||||||
|
assert.NotEmpty(t, algorithms)
|
||||||
|
assert.Contains(t, algorithms, string(model.SHA256))
|
||||||
|
assert.Contains(t, algorithms, string(model.Sha256Simd))
|
||||||
|
// Verify case-insensitive check
|
||||||
|
assert.True(t, model.IsAlgorithmSupported("SHA256"))
|
||||||
|
assert.True(t, model.IsAlgorithmSupported("sha256"))
|
||||||
|
assert.True(t, model.IsAlgorithmSupported("sha256-simd"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAlgorithmSupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
algorithm string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SHA256",
|
||||||
|
algorithm: "SHA256",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SHA256 lowercase",
|
||||||
|
algorithm: "sha256",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Sha256Simd",
|
||||||
|
algorithm: "sha256-simd",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Sha256Simd mixed case",
|
||||||
|
algorithm: "Sha256-Simd",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported",
|
||||||
|
algorithm: "UNSUPPORTED",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result := model.IsAlgorithmSupported(tt.algorithm)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_GetHashType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA512)
|
||||||
|
assert.Equal(t, model.SHA512, tool.GetHashType())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_AllHashTypes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
hashTypes := []model.HashType{
|
||||||
|
model.MD5,
|
||||||
|
model.SHA1,
|
||||||
|
model.SHA224,
|
||||||
|
model.SHA256,
|
||||||
|
model.SHA384,
|
||||||
|
model.SHA512,
|
||||||
|
model.Sha256Simd,
|
||||||
|
model.BLAKE3,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hashType := range hashTypes {
|
||||||
|
tool := model.NewHashTool(hashType)
|
||||||
|
result, err := tool.HashString("test")
|
||||||
|
require.NoError(t, err, "HashType: %v", hashType)
|
||||||
|
assert.NotEmpty(t, result, "HashType: %v", hashType)
|
||||||
|
assert.Equal(t, hashType, tool.GetHashType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_CompareHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
data := "test data"
|
||||||
|
|
||||||
|
// Generate hash
|
||||||
|
hash, err := tool.HashString(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data string
|
||||||
|
expectedHash string
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "匹配的哈希值",
|
||||||
|
data: data,
|
||||||
|
expectedHash: hash,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "大小写不同但内容相同",
|
||||||
|
data: data,
|
||||||
|
expectedHash: strings.ToUpper(hash),
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "不匹配的哈希值",
|
||||||
|
data: data,
|
||||||
|
expectedHash: "invalid_hash",
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "不同的数据",
|
||||||
|
data: "different data",
|
||||||
|
expectedHash: hash,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
match, err := tool.CompareHash(tt.data, tt.expectedHash)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.shouldMatch, match)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashTool_CompareFileHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create a temporary file
|
||||||
|
tmpFile := t.TempDir() + "/test.txt"
|
||||||
|
content := []byte("test file content")
|
||||||
|
err := os.WriteFile(tmpFile, content, 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tool := model.NewHashTool(model.SHA256)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Generate expected hash
|
||||||
|
expectedHash, err := tool.HashFile(ctx, tmpFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filePath string
|
||||||
|
expectedHash string
|
||||||
|
shouldMatch bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "匹配的文件哈希",
|
||||||
|
filePath: tmpFile,
|
||||||
|
expectedHash: expectedHash,
|
||||||
|
shouldMatch: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "大小写不同但内容相同",
|
||||||
|
filePath: tmpFile,
|
||||||
|
expectedHash: strings.ToUpper(expectedHash),
|
||||||
|
shouldMatch: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "不匹配的文件哈希",
|
||||||
|
filePath: tmpFile,
|
||||||
|
expectedHash: "invalid_hash",
|
||||||
|
shouldMatch: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "文件不存在",
|
||||||
|
filePath: "/nonexistent/file",
|
||||||
|
expectedHash: expectedHash,
|
||||||
|
shouldMatch: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
match, err := tool.CompareFileHash(ctx, tt.filePath, tt.expectedHash)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.shouldMatch, match)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashList_GetHashType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create mock hash data
|
||||||
|
mockHash := &mockHashData{
|
||||||
|
key: "test-key",
|
||||||
|
hash: "test-hash",
|
||||||
|
hashType: model.SHA256,
|
||||||
|
}
|
||||||
|
|
||||||
|
hashList := model.HashList{mockHash}
|
||||||
|
assert.Equal(t, model.SHA256, hashList.GetHashType())
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockHashData implements HashData interface for testing.
|
||||||
|
type mockHashData struct {
|
||||||
|
key string
|
||||||
|
hash string
|
||||||
|
hashType model.HashType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashData) Key() string {
|
||||||
|
return m.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashData) Hash() string {
|
||||||
|
return m.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashData) Type() model.HashType {
|
||||||
|
return m.hashType
|
||||||
|
}
|
||||||
577
api/model/operation.go
Normal file
577
api/model/operation.go
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 操作来源类型 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// Source 表示操作来源,用于区分不同系统模块(IRP、DOIP)。
|
||||||
|
type Source string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpSourceIRP Source = "IRP"
|
||||||
|
OpSourceDOIP Source = "DOIP"
|
||||||
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 操作类型枚举 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// Type 表示操作的具体类型。
|
||||||
|
type Type string
|
||||||
|
|
||||||
|
// DOIP 操作类型枚举。
|
||||||
|
const (
|
||||||
|
OpTypeHello Type = "Hello"
|
||||||
|
OpTypeRetrieve Type = "Retrieve"
|
||||||
|
OpTypeCreate Type = "Create"
|
||||||
|
OpTypeDelete Type = "Delete"
|
||||||
|
OpTypeUpdate Type = "Update"
|
||||||
|
OpTypeSearch Type = "Search"
|
||||||
|
OpTypeListOperations Type = "ListOperations"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IRP 操作类型枚举。
|
||||||
|
const (
|
||||||
|
OpTypeOCReserved Type = "OC_RESERVED"
|
||||||
|
OpTypeOCResolution Type = "OC_RESOLUTION"
|
||||||
|
OpTypeOCGetSiteInfo Type = "OC_GET_SITEINFO"
|
||||||
|
OpTypeOCCreateHandle Type = "OC_CREATE_HANDLE"
|
||||||
|
OpTypeOCDeleteHandle Type = "OC_DELETE_HANDLE"
|
||||||
|
OpTypeOCAddValue Type = "OC_ADD_VALUE"
|
||||||
|
OpTypeOCRemoveValue Type = "OC_REMOVE_VALUE"
|
||||||
|
OpTypeOCModifyValue Type = "OC_MODIFY_VALUE"
|
||||||
|
OpTypeOCListHandle Type = "OC_LIST_HANDLE"
|
||||||
|
OpTypeOCListNA Type = "OC_LIST_NA"
|
||||||
|
OpTypeOCResolutionDOID Type = "OC_RESOLUTION_DOID"
|
||||||
|
OpTypeOCCreateDOID Type = "OC_CREATE_DOID"
|
||||||
|
OpTypeOCDeleteDOID Type = "OC_DELETE_DOID"
|
||||||
|
OpTypeOCUpdateDOID Type = "OC_UPDATE_DOID"
|
||||||
|
OpTypeOCBatchCreateDOID Type = "OC_BATCH_CREATE_DOID"
|
||||||
|
OpTypeOCResolutionDOIDRecursive Type = "OC_RESOLUTION_DOID_RECURSIVE"
|
||||||
|
OpTypeOCGetUsers Type = "OC_GET_USERS"
|
||||||
|
OpTypeOCGetRepos Type = "OC_GET_REPOS"
|
||||||
|
OpTypeOCVerifyIRS Type = "OC_VERIFY_IRS"
|
||||||
|
OpTypeOCResolveGRS Type = "OC_RESOLVE_GRS"
|
||||||
|
OpTypeOCCreateOrgGRS Type = "OC_CREATE_ORG_GRS"
|
||||||
|
OpTypeOCUpdateOrgGRS Type = "OC_UPDATE_ORG_GRS"
|
||||||
|
OpTypeOCDeleteOrgGRS Type = "OC_DELETE_ORG_GRS"
|
||||||
|
OpTypeOCSyncOrgIRSParent Type = "OC_SYNC_ORG_IRS_PARENT"
|
||||||
|
OpTypeOCUpdateOrgIRSParent Type = "OC_UPDATE_ORG_IRS_PARENT"
|
||||||
|
OpTypeOCDeleteOrgIRSParent Type = "OC_DELETE_ORG_IRS_PARENT"
|
||||||
|
OpTypeOCChallengeResponse Type = "OC_CHALLENGE_RESPONSE"
|
||||||
|
OpTypeOCVerifyChallenge Type = "OC_VERIFY_CHALLENGE"
|
||||||
|
OpTypeOCSessionSetup Type = "OC_SESSION_SETUP"
|
||||||
|
OpTypeOCSessionTerminate Type = "OC_SESSION_TERMINATE"
|
||||||
|
OpTypeOCSessionExchangeKey Type = "OC_SESSION_EXCHANGEKEY"
|
||||||
|
OpTypeOCVerifyRouter Type = "OC_VERIFY_ROUTER"
|
||||||
|
OpTypeOCQueryRouter Type = "OC_QUERY_ROUTER"
|
||||||
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 操作类型检索工具 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// allOpTypes 存储不同来源的操作类型列表,用于快速查找和验证。
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // 全局常量映射用于操作类型查找
|
||||||
|
var allOpTypes = map[Source][]Type{
|
||||||
|
OpSourceDOIP: {
|
||||||
|
OpTypeHello, OpTypeRetrieve, OpTypeCreate,
|
||||||
|
OpTypeDelete, OpTypeUpdate, OpTypeSearch,
|
||||||
|
OpTypeListOperations,
|
||||||
|
},
|
||||||
|
OpSourceIRP: {
|
||||||
|
OpTypeOCReserved, OpTypeOCResolution, OpTypeOCGetSiteInfo,
|
||||||
|
OpTypeOCCreateHandle, OpTypeOCDeleteHandle, OpTypeOCAddValue,
|
||||||
|
OpTypeOCRemoveValue, OpTypeOCModifyValue, OpTypeOCListHandle,
|
||||||
|
OpTypeOCListNA, OpTypeOCResolutionDOID, OpTypeOCCreateDOID,
|
||||||
|
OpTypeOCDeleteDOID, OpTypeOCUpdateDOID, OpTypeOCBatchCreateDOID,
|
||||||
|
OpTypeOCResolutionDOIDRecursive, OpTypeOCGetUsers, OpTypeOCGetRepos,
|
||||||
|
OpTypeOCVerifyIRS, OpTypeOCResolveGRS, OpTypeOCCreateOrgGRS,
|
||||||
|
OpTypeOCUpdateOrgGRS, OpTypeOCDeleteOrgGRS, OpTypeOCSyncOrgIRSParent,
|
||||||
|
OpTypeOCUpdateOrgIRSParent, OpTypeOCDeleteOrgIRSParent,
|
||||||
|
OpTypeOCChallengeResponse, OpTypeOCVerifyChallenge,
|
||||||
|
OpTypeOCSessionSetup, OpTypeOCSessionTerminate,
|
||||||
|
OpTypeOCSessionExchangeKey, OpTypeOCVerifyRouter, OpTypeOCQueryRouter,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpTypesBySource 返回指定来源的可用操作类型列表。
|
||||||
|
func GetOpTypesBySource(source Source) []Type {
|
||||||
|
return allOpTypes[source]
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidOpType 判断指定操作类型在给定来源下是否合法。
|
||||||
|
func IsValidOpType(source Source, opType Type) bool {
|
||||||
|
for _, t := range GetOpTypesBySource(source) {
|
||||||
|
if t == opType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 操作记录结构 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// Operation 表示一次完整的操作记录。
|
||||||
|
// 用于记录系统中的操作行为,包含操作元数据、数据标识、操作者信息以及请求/响应的哈希值。
|
||||||
|
type Operation struct {
|
||||||
|
OpID string `json:"opId" validate:"max=32"`
|
||||||
|
Timestamp time.Time `json:"timestamp" validate:"required"`
|
||||||
|
OpSource Source `json:"opSource" validate:"required,oneof=IRP DOIP"`
|
||||||
|
OpType Type `json:"opType" validate:"required"`
|
||||||
|
DoPrefix string `json:"doPrefix" validate:"required,max=512"`
|
||||||
|
DoRepository string `json:"doRepository" validate:"required,max=512"`
|
||||||
|
Doid string `json:"doid" validate:"required,max=512"`
|
||||||
|
ProducerID string `json:"producerId" validate:"required,max=512"`
|
||||||
|
OpActor string `json:"opActor" validate:"max=64"`
|
||||||
|
RequestBodyHash *string `json:"requestBodyHash" validate:"omitempty,max=128"`
|
||||||
|
ResponseBodyHash *string `json:"responseBodyHash" validate:"omitempty,max=128"`
|
||||||
|
Ack func() bool `json:"-"`
|
||||||
|
Nack func() bool `json:"-"`
|
||||||
|
binary []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 构造函数 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// NewFullOperation 创建包含所有字段的完整 Operation。
|
||||||
|
// 自动完成哈希计算和字段校验,确保创建的 Operation 是完整且有效的。
|
||||||
|
func NewFullOperation(
|
||||||
|
opSource Source,
|
||||||
|
opType Type,
|
||||||
|
doPrefix, doRepository, doid string,
|
||||||
|
producerID string,
|
||||||
|
opActor string,
|
||||||
|
requestBody, responseBody interface{},
|
||||||
|
timestamp time.Time,
|
||||||
|
) (*Operation, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new full operation",
|
||||||
|
"opSource", opSource,
|
||||||
|
"opType", opType,
|
||||||
|
"doPrefix", doPrefix,
|
||||||
|
"doRepository", doRepository,
|
||||||
|
"doid", doid,
|
||||||
|
"producerID", producerID,
|
||||||
|
"opActor", opActor,
|
||||||
|
)
|
||||||
|
op := &Operation{
|
||||||
|
Timestamp: timestamp,
|
||||||
|
OpSource: opSource,
|
||||||
|
OpType: opType,
|
||||||
|
DoPrefix: doPrefix,
|
||||||
|
DoRepository: doRepository,
|
||||||
|
Doid: doid,
|
||||||
|
ProducerID: producerID,
|
||||||
|
OpActor: opActor,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Setting request body hash")
|
||||||
|
if err := op.RequestBodyFlexible(requestBody); err != nil {
|
||||||
|
log.Error("Failed to set request body hash",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Setting response body hash")
|
||||||
|
if err := op.ResponseBodyFlexible(responseBody); err != nil {
|
||||||
|
log.Error("Failed to set response body hash",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Checking and initializing operation")
|
||||||
|
if err := op.CheckAndInit(); err != nil {
|
||||||
|
log.Error("Failed to check and init operation",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Full operation created successfully",
|
||||||
|
"opID", op.OpID,
|
||||||
|
)
|
||||||
|
return op, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 接口实现 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
func (o *Operation) Key() string {
|
||||||
|
return o.OpID
|
||||||
|
}
|
||||||
|
|
||||||
|
// OperationHashData 实现 HashData 接口,用于存储 Operation 的哈希计算结果。
|
||||||
|
type OperationHashData struct {
|
||||||
|
key string
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OperationHashData) Key() string {
|
||||||
|
return o.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OperationHashData) Hash() string {
|
||||||
|
return o.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OperationHashData) Type() HashType {
|
||||||
|
return Sha256Simd
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoHash 计算 Operation 的整体哈希值,用于数据完整性验证。
|
||||||
|
// 哈希基于序列化后的二进制数据计算,确保操作记录的不可篡改性。
|
||||||
|
func (o *Operation) DoHash(_ context.Context) (HashData, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Computing hash for operation",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
hashTool := GetHashTool(Sha256Simd)
|
||||||
|
binary, err := o.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal operation for hash",
|
||||||
|
"error", err,
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to marshal operation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Computing hash bytes",
|
||||||
|
"opID", o.OpID,
|
||||||
|
"binaryLength", len(binary),
|
||||||
|
)
|
||||||
|
hash, err := hashTool.HashBytes(binary)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to compute hash",
|
||||||
|
"error", err,
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to compute hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Hash computed successfully",
|
||||||
|
"opID", o.OpID,
|
||||||
|
"hash", hash,
|
||||||
|
)
|
||||||
|
return OperationHashData{
|
||||||
|
key: o.OpID,
|
||||||
|
hash: hash,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== CBOR 序列化相关 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// operationData 用于 CBOR 序列化/反序列化的中间结构。
|
||||||
|
// 排除函数字段和缓存字段,仅包含可序列化的数据字段。
|
||||||
|
type operationData struct {
|
||||||
|
OpID *string `cbor:"opId"`
|
||||||
|
Timestamp *time.Time `cbor:"timestamp"`
|
||||||
|
OpSource *Source `cbor:"opSource"`
|
||||||
|
OpType *Type `cbor:"opType"`
|
||||||
|
DoPrefix *string `cbor:"doPrefix"`
|
||||||
|
DoRepository *string `cbor:"doRepository"`
|
||||||
|
Doid *string `cbor:"doid"`
|
||||||
|
ProducerID *string `cbor:"producerId"`
|
||||||
|
OpActor *string `cbor:"opActor"`
|
||||||
|
RequestBodyHash *string `cbor:"requestBodyHash"`
|
||||||
|
ResponseBodyHash *string `cbor:"responseBodyHash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// toOperationData 将 Operation 转换为 operationData,用于序列化。
|
||||||
|
func (o *Operation) toOperationData() *operationData {
|
||||||
|
return &operationData{
|
||||||
|
OpID: &o.OpID,
|
||||||
|
Timestamp: &o.Timestamp,
|
||||||
|
OpSource: &o.OpSource,
|
||||||
|
OpType: &o.OpType,
|
||||||
|
DoPrefix: &o.DoPrefix,
|
||||||
|
DoRepository: &o.DoRepository,
|
||||||
|
Doid: &o.Doid,
|
||||||
|
ProducerID: &o.ProducerID,
|
||||||
|
OpActor: &o.OpActor,
|
||||||
|
RequestBodyHash: o.RequestBodyHash,
|
||||||
|
ResponseBodyHash: o.ResponseBodyHash,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fromOperationData 从 operationData 填充 Operation,用于反序列化。
|
||||||
|
func (o *Operation) fromOperationData(opData *operationData) {
|
||||||
|
if opData == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if opData.OpID != nil {
|
||||||
|
o.OpID = *opData.OpID
|
||||||
|
}
|
||||||
|
if opData.Timestamp != nil {
|
||||||
|
o.Timestamp = *opData.Timestamp
|
||||||
|
}
|
||||||
|
if opData.OpSource != nil {
|
||||||
|
o.OpSource = *opData.OpSource
|
||||||
|
}
|
||||||
|
if opData.OpType != nil {
|
||||||
|
o.OpType = *opData.OpType
|
||||||
|
}
|
||||||
|
if opData.DoPrefix != nil {
|
||||||
|
o.DoPrefix = *opData.DoPrefix
|
||||||
|
}
|
||||||
|
if opData.DoRepository != nil {
|
||||||
|
o.DoRepository = *opData.DoRepository
|
||||||
|
}
|
||||||
|
if opData.Doid != nil {
|
||||||
|
o.Doid = *opData.Doid
|
||||||
|
}
|
||||||
|
if opData.ProducerID != nil {
|
||||||
|
o.ProducerID = *opData.ProducerID
|
||||||
|
}
|
||||||
|
if opData.OpActor != nil {
|
||||||
|
o.OpActor = *opData.OpActor
|
||||||
|
}
|
||||||
|
if opData.RequestBodyHash != nil {
|
||||||
|
hash := *opData.RequestBodyHash
|
||||||
|
o.RequestBodyHash = &hash
|
||||||
|
}
|
||||||
|
if opData.ResponseBodyHash != nil {
|
||||||
|
hash := *opData.ResponseBodyHash
|
||||||
|
o.ResponseBodyHash = &hash
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalBinary 将 Operation 序列化为 CBOR 格式的二进制数据。
|
||||||
|
// 实现 encoding.BinaryMarshaler 接口。
|
||||||
|
// 使用 Canonical CBOR 编码确保序列化结果的一致性,使用缓存机制避免重复序列化。
|
||||||
|
func (o *Operation) MarshalBinary() ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling operation to CBOR binary",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
if o.binary != nil {
|
||||||
|
log.Debug("Using cached binary data",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return o.binary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
opData := o.toOperationData()
|
||||||
|
|
||||||
|
log.Debug("Marshaling operation data to canonical CBOR",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
binary, err := helpers.MarshalCanonical(opData)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal operation to CBOR",
|
||||||
|
"error", err,
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to marshal operation to CBOR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.binary = binary
|
||||||
|
|
||||||
|
log.Debug("Operation marshaled successfully",
|
||||||
|
"opID", o.OpID,
|
||||||
|
"binaryLength", len(binary),
|
||||||
|
)
|
||||||
|
return binary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProducerID 返回 ProducerID,实现 Trustlog 接口。
|
||||||
|
func (o *Operation) GetProducerID() string {
|
||||||
|
return o.ProducerID
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalBinary 从 CBOR 格式的二进制数据反序列化为 Operation。
|
||||||
|
// 实现 encoding.BinaryUnmarshaler 接口。
|
||||||
|
func (o *Operation) UnmarshalBinary(data []byte) error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Unmarshaling operation from CBOR binary",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
if len(data) == 0 {
|
||||||
|
log.Error("Data is empty")
|
||||||
|
return errors.New("data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
opData := &operationData{}
|
||||||
|
|
||||||
|
log.Debug("Unmarshaling operation data from CBOR")
|
||||||
|
if err := helpers.Unmarshal(data, opData); err != nil {
|
||||||
|
log.Error("Failed to unmarshal operation from CBOR",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("failed to unmarshal operation from CBOR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.fromOperationData(opData)
|
||||||
|
|
||||||
|
o.binary = data
|
||||||
|
|
||||||
|
log.Debug("Operation unmarshaled successfully",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 哈希设置方法 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// setBodyHashFlexible 根据输入数据类型计算哈希,支持 string 和 []byte。
|
||||||
|
// 使用固定的 Sha256Simd 算法。
|
||||||
|
// 如果输入为 nil 或空,则目标指针设置为 nil,表示该字段未设置。
|
||||||
|
func (o *Operation) setBodyHashFlexible(data interface{}, target **string) error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Setting body hash flexible",
|
||||||
|
"opID", o.OpID,
|
||||||
|
"dataType", fmt.Sprintf("%T", data),
|
||||||
|
)
|
||||||
|
if data == nil {
|
||||||
|
log.Debug("Data is nil, setting target to nil")
|
||||||
|
*target = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hashTool := GetHashTool(Sha256Simd)
|
||||||
|
var raw []byte
|
||||||
|
|
||||||
|
switch v := data.(type) {
|
||||||
|
case string:
|
||||||
|
if v == "" {
|
||||||
|
log.Debug("String data is empty, setting target to nil")
|
||||||
|
*target = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw = []byte(v)
|
||||||
|
log.Debug("Converting string to bytes",
|
||||||
|
"stringLength", len(v),
|
||||||
|
)
|
||||||
|
case []byte:
|
||||||
|
if len(v) == 0 {
|
||||||
|
log.Debug("Byte data is empty, setting target to nil")
|
||||||
|
*target = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw = v
|
||||||
|
log.Debug("Using byte data directly",
|
||||||
|
"byteLength", len(v),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
log.Error("Unsupported data type",
|
||||||
|
"dataType", fmt.Sprintf("%T", v),
|
||||||
|
)
|
||||||
|
return fmt.Errorf("unsupported data type %T", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Computing hash for body data",
|
||||||
|
"dataLength", len(raw),
|
||||||
|
)
|
||||||
|
hash, err := hashTool.HashBytes(raw)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to compute hash",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*target = &hash
|
||||||
|
log.Debug("Body hash set successfully",
|
||||||
|
"hash", hash,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestBodyFlexible 设置请求体哈希值。
|
||||||
|
// 支持 string 和 []byte 类型,nil 或空值会将 RequestBodyHash 设置为 nil。
|
||||||
|
func (o *Operation) RequestBodyFlexible(data interface{}) error {
|
||||||
|
return o.setBodyHashFlexible(data, &o.RequestBodyHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyFlexible 设置响应体哈希值。
|
||||||
|
// 支持 string 和 []byte 类型,nil 或空值会将 ResponseBodyHash 设置为 nil。
|
||||||
|
func (o *Operation) ResponseBodyFlexible(data interface{}) error {
|
||||||
|
return o.setBodyHashFlexible(data, &o.ResponseBodyHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 链式调用支持 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// WithRequestBody 设置请求体哈希并返回自身,支持链式调用。
|
||||||
|
func (o *Operation) WithRequestBody(data []byte) *Operation {
|
||||||
|
_ = o.RequestBodyFlexible(data)
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithResponseBody 设置响应体哈希并返回自身,支持链式调用。
|
||||||
|
func (o *Operation) WithResponseBody(data []byte) *Operation {
|
||||||
|
_ = o.ResponseBodyFlexible(data)
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 初始化与验证 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// CheckAndInit 校验并初始化 Operation。
|
||||||
|
// 自动填充缺失字段(OpID、OpActor),执行业务逻辑验证(doid 格式),
|
||||||
|
// 字段非空验证由 validate 标签处理。
|
||||||
|
func (o *Operation) CheckAndInit() error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Checking and initializing operation",
|
||||||
|
"opSource", o.OpSource,
|
||||||
|
"opType", o.OpType,
|
||||||
|
"doid", o.Doid,
|
||||||
|
)
|
||||||
|
if o.OpID == "" {
|
||||||
|
o.OpID = helpers.NewUUIDv7()
|
||||||
|
log.Debug("Generated new OpID",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.OpActor == "" {
|
||||||
|
o.OpActor = "SYSTEM"
|
||||||
|
log.Debug("Set default OpActor to SYSTEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPrefix := fmt.Sprintf("%s/%s", o.DoPrefix, o.DoRepository)
|
||||||
|
if !strings.HasPrefix(o.Doid, expectedPrefix) {
|
||||||
|
log.Error("Doid format validation failed",
|
||||||
|
"doid", o.Doid,
|
||||||
|
"expectedPrefix", expectedPrefix,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("doid must start with '%s'", expectedPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Validating operation struct")
|
||||||
|
if err := helpers.GetValidator().Struct(o); err != nil {
|
||||||
|
log.Error("Operation validation failed",
|
||||||
|
"error", err,
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Operation checked and initialized successfully",
|
||||||
|
"opID", o.OpID,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
593
api/model/operation_test.go
Normal file
593
api/model/operation_test.go
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOperation_Key(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "test-op-id",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "test-op-id", op.Key())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_CheckAndInit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
op *model.Operation
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid operation",
|
||||||
|
op: &model.Operation{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto generate OpID",
|
||||||
|
op: &model.Operation{
|
||||||
|
OpID: "", // Will be auto-generated
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto set OpActor",
|
||||||
|
op: &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "", // Will be set to "SYSTEM"
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid doid format",
|
||||||
|
op: &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "invalid/123", // Doesn't start with "test/repo"
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := tt.op.CheckAndInit()
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
if tt.name == "auto generate OpID" {
|
||||||
|
assert.NotEmpty(t, tt.op.OpID)
|
||||||
|
}
|
||||||
|
if tt.name == "auto set OpActor" {
|
||||||
|
assert.Equal(t, "SYSTEM", tt.op.OpActor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_RequestBodyFlexible(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes",
|
||||||
|
input: []byte("test"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
input: nil,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty bytes",
|
||||||
|
input: []byte{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid type",
|
||||||
|
input: 123,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
op := &model.Operation{}
|
||||||
|
err := op.RequestBodyFlexible(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_ResponseBodyFlexible(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes",
|
||||||
|
input: []byte("test"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
input: nil,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
op := &model.Operation{}
|
||||||
|
err := op.ResponseBodyFlexible(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_WithRequestBody(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{}
|
||||||
|
result := op.WithRequestBody([]byte("test"))
|
||||||
|
assert.Equal(t, op, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_WithResponseBody(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{}
|
||||||
|
result := op.WithResponseBody([]byte("test"))
|
||||||
|
assert.Equal(t, op, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_MarshalUnmarshalBinary(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := original.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
result := &model.Operation{}
|
||||||
|
err = result.UnmarshalBinary(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Equal(t, original.OpID, result.OpID)
|
||||||
|
assert.Equal(t, original.OpSource, result.OpSource)
|
||||||
|
assert.Equal(t, original.OpType, result.OpType)
|
||||||
|
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||||
|
assert.Equal(t, original.DoRepository, result.DoRepository)
|
||||||
|
assert.Equal(t, original.Doid, result.Doid)
|
||||||
|
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||||
|
assert.Equal(t, original.OpActor, result.OpActor)
|
||||||
|
// 验证纳秒精度被保留
|
||||||
|
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||||
|
"时间戳的纳秒精度应该被保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_MarshalBinary_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
}
|
||||||
|
// MarshalBinary should succeed even without CheckAndInit
|
||||||
|
// It just serializes the data
|
||||||
|
data, err := op.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_UnmarshalBinary_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{}
|
||||||
|
err := op.UnmarshalBinary([]byte{})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_GetProducerID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
ProducerID: "producer-123",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "producer-123", op.GetProducerID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_DoHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
hashData, err := op.DoHash(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, hashData)
|
||||||
|
assert.Equal(t, op.OpID, hashData.Key())
|
||||||
|
assert.NotEmpty(t, hashData.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperationHashData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// OperationHashData is created through DoHash, test it indirectly
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
hashData, err := op.DoHash(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, hashData)
|
||||||
|
assert.Equal(t, "op-123", hashData.Key())
|
||||||
|
assert.NotEmpty(t, hashData.Hash())
|
||||||
|
assert.Equal(t, model.Sha256Simd, hashData.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_UnmarshalBinary_InvalidData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{}
|
||||||
|
err := op.UnmarshalBinary([]byte("invalid-cbor-data"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to unmarshal operation from CBOR")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_MarshalTrustlog_EmptyProducerID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create an operation with empty ProducerID
|
||||||
|
// MarshalBinary will fail validation, but MarshalTrustlog checks ProducerID first
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "", // Empty ProducerID
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||||
|
_, err := model.MarshalTrustlog(op, config)
|
||||||
|
// MarshalTrustlog checks ProducerID before calling MarshalBinary
|
||||||
|
require.Error(t, err)
|
||||||
|
// Error could be from ProducerID check or MarshalBinary validation
|
||||||
|
assert.True(t,
|
||||||
|
err.Error() == "producerID cannot be empty" ||
|
||||||
|
strings.Contains(err.Error(), "ProducerID") ||
|
||||||
|
strings.Contains(err.Error(), "producerID"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOperation_MarshalTrustlog_NilSigner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
op := &model.Operation{
|
||||||
|
OpID: "op-123",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := op.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := model.EnvelopeConfig{Signer: nil}
|
||||||
|
_, err = model.MarshalTrustlog(op, config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "signer is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOpTypesBySource(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source model.Source
|
||||||
|
wantTypes []model.Type
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IRP操作类型",
|
||||||
|
source: model.OpSourceIRP,
|
||||||
|
wantTypes: []model.Type{
|
||||||
|
model.OpTypeOCCreateHandle,
|
||||||
|
model.OpTypeOCDeleteHandle,
|
||||||
|
model.OpTypeOCAddValue,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DOIP操作类型",
|
||||||
|
source: model.OpSourceDOIP,
|
||||||
|
wantTypes: []model.Type{
|
||||||
|
model.OpTypeHello,
|
||||||
|
model.OpTypeCreate,
|
||||||
|
model.OpTypeDelete,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
opTypes := model.GetOpTypesBySource(tt.source)
|
||||||
|
assert.NotNil(t, opTypes)
|
||||||
|
// Verify expected types are included
|
||||||
|
for _, expectedType := range tt.wantTypes {
|
||||||
|
assert.Contains(t, opTypes, expectedType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidOpType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source model.Source
|
||||||
|
opType model.Type
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IRP有效操作类型",
|
||||||
|
source: model.OpSourceIRP,
|
||||||
|
opType: model.OpTypeOCCreateHandle,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IRP无效操作类型",
|
||||||
|
source: model.OpSourceIRP,
|
||||||
|
opType: model.OpTypeHello,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DOIP有效操作类型",
|
||||||
|
source: model.OpSourceDOIP,
|
||||||
|
opType: model.OpTypeHello,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DOIP无效操作类型",
|
||||||
|
source: model.OpSourceDOIP,
|
||||||
|
opType: model.OpTypeOCCreateHandle,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "未知来源和类型",
|
||||||
|
source: model.Source("unknown"),
|
||||||
|
opType: model.Type("unknown"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result := model.IsValidOpType(tt.source, tt.opType)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFullOperation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opSource model.Source
|
||||||
|
opType model.Type
|
||||||
|
doPrefix string
|
||||||
|
doRepository string
|
||||||
|
doid string
|
||||||
|
producerID string
|
||||||
|
opActor string
|
||||||
|
requestBody interface{}
|
||||||
|
responseBody interface{}
|
||||||
|
timestamp time.Time
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "成功创建完整操作",
|
||||||
|
opSource: model.OpSourceIRP,
|
||||||
|
opType: model.OpTypeOCCreateHandle,
|
||||||
|
doPrefix: "test",
|
||||||
|
doRepository: "repo",
|
||||||
|
doid: "test/repo/123",
|
||||||
|
producerID: "producer-1",
|
||||||
|
opActor: "actor-1",
|
||||||
|
requestBody: []byte(`{"key": "value"}`),
|
||||||
|
responseBody: []byte(`{"status": "ok"}`),
|
||||||
|
timestamp: time.Now(),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空请求体和响应体",
|
||||||
|
opSource: model.OpSourceIRP,
|
||||||
|
opType: model.OpTypeOCCreateHandle,
|
||||||
|
doPrefix: "test",
|
||||||
|
doRepository: "repo",
|
||||||
|
doid: "test/repo/123",
|
||||||
|
producerID: "producer-1",
|
||||||
|
opActor: "actor-1",
|
||||||
|
requestBody: nil,
|
||||||
|
responseBody: nil,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "字符串类型的请求体",
|
||||||
|
opSource: model.OpSourceIRP,
|
||||||
|
opType: model.OpTypeOCCreateHandle,
|
||||||
|
doPrefix: "test",
|
||||||
|
doRepository: "repo",
|
||||||
|
doid: "test/repo/123",
|
||||||
|
producerID: "producer-1",
|
||||||
|
opActor: "actor-1",
|
||||||
|
requestBody: "string body",
|
||||||
|
responseBody: "string response",
|
||||||
|
timestamp: time.Now(),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
op, err := model.NewFullOperation(
|
||||||
|
tt.opSource,
|
||||||
|
tt.opType,
|
||||||
|
tt.doPrefix,
|
||||||
|
tt.doRepository,
|
||||||
|
tt.doid,
|
||||||
|
tt.producerID,
|
||||||
|
tt.opActor,
|
||||||
|
tt.requestBody,
|
||||||
|
tt.responseBody,
|
||||||
|
tt.timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, op)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, op)
|
||||||
|
assert.Equal(t, tt.opSource, op.OpSource)
|
||||||
|
assert.Equal(t, tt.opType, op.OpType)
|
||||||
|
assert.Equal(t, tt.doPrefix, op.DoPrefix)
|
||||||
|
assert.Equal(t, tt.doRepository, op.DoRepository)
|
||||||
|
assert.Equal(t, tt.doid, op.Doid)
|
||||||
|
assert.Equal(t, tt.producerID, op.ProducerID)
|
||||||
|
assert.Equal(t, tt.opActor, op.OpActor)
|
||||||
|
assert.NotEmpty(t, op.OpID) // Should be auto-generated
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
56
api/model/operation_timestamp_test.go
Normal file
56
api/model/operation_timestamp_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOperation_TimestampNanosecondPrecision 验证 Operation 的时间戳在 CBOR 序列化/反序列化后能保留纳秒精度
|
||||||
|
func TestOperation_TimestampNanosecondPrecision(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 创建一个包含纳秒精度的时间戳
|
||||||
|
timestamp := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||||
|
|
||||||
|
original := &model.Operation{
|
||||||
|
OpID: "op-nanosecond-test",
|
||||||
|
Timestamp: timestamp,
|
||||||
|
OpSource: model.OpSourceIRP,
|
||||||
|
OpType: model.OpTypeOCCreateHandle,
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
OpActor: "actor-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := original.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||||
|
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 序列化
|
||||||
|
data, err := original.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// 反序列化
|
||||||
|
result := &model.Operation{}
|
||||||
|
err = result.UnmarshalBinary(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Decoded timestamp: %v", result.Timestamp)
|
||||||
|
t.Logf("Decoded nanoseconds: %d", result.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 验证纳秒精度被完整保留
|
||||||
|
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||||
|
"时间戳的纳秒精度应该被完整保留")
|
||||||
|
assert.Equal(t, original.Timestamp.Nanosecond(), result.Timestamp.Nanosecond(),
|
||||||
|
"纳秒部分应该相等")
|
||||||
|
}
|
||||||
146
api/model/proof.go
Normal file
146
api/model/proof.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MerkleTreeProofItem 表示Merkle树证明项.
|
||||||
|
type MerkleTreeProofItem struct {
|
||||||
|
Floor uint32 // 层级
|
||||||
|
Hash string // 哈希值
|
||||||
|
Left bool // 是否为左节点
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proof 表示取证证明.
|
||||||
|
type Proof struct {
|
||||||
|
ColItems []*MerkleTreeProofItem // 集合项证明
|
||||||
|
RawItems []*MerkleTreeProofItem // 原始项证明
|
||||||
|
ColRootItem []*MerkleTreeProofItem // 集合根项证明
|
||||||
|
RawRootItem []*MerkleTreeProofItem // 原始根项证明
|
||||||
|
Sign string // 签名
|
||||||
|
Version string // 版本号
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProofFromProtobuf 将protobuf的Proof转换为model.Proof.
|
||||||
|
func ProofFromProtobuf(pbProof *pb.Proof) *Proof {
|
||||||
|
if pbProof == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proof := &Proof{
|
||||||
|
Sign: pbProof.GetSign(),
|
||||||
|
Version: pbProof.GetVersion(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 ColItems
|
||||||
|
if pbColItems := pbProof.GetColItems(); len(pbColItems) > 0 {
|
||||||
|
proof.ColItems = make([]*MerkleTreeProofItem, 0, len(pbColItems))
|
||||||
|
for _, item := range pbColItems {
|
||||||
|
proof.ColItems = append(proof.ColItems, &MerkleTreeProofItem{
|
||||||
|
Floor: item.GetFloor(),
|
||||||
|
Hash: item.GetHash(),
|
||||||
|
Left: item.GetLeft(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 RawItems
|
||||||
|
if pbRawItems := pbProof.GetRawItems(); len(pbRawItems) > 0 {
|
||||||
|
proof.RawItems = make([]*MerkleTreeProofItem, 0, len(pbRawItems))
|
||||||
|
for _, item := range pbRawItems {
|
||||||
|
proof.RawItems = append(proof.RawItems, &MerkleTreeProofItem{
|
||||||
|
Floor: item.GetFloor(),
|
||||||
|
Hash: item.GetHash(),
|
||||||
|
Left: item.GetLeft(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 ColRootItem
|
||||||
|
if pbColRootItem := pbProof.GetColRootItem(); len(pbColRootItem) > 0 {
|
||||||
|
proof.ColRootItem = make([]*MerkleTreeProofItem, 0, len(pbColRootItem))
|
||||||
|
for _, item := range pbColRootItem {
|
||||||
|
proof.ColRootItem = append(proof.ColRootItem, &MerkleTreeProofItem{
|
||||||
|
Floor: item.GetFloor(),
|
||||||
|
Hash: item.GetHash(),
|
||||||
|
Left: item.GetLeft(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 RawRootItem
|
||||||
|
if pbRawRootItem := pbProof.GetRawRootItem(); len(pbRawRootItem) > 0 {
|
||||||
|
proof.RawRootItem = make([]*MerkleTreeProofItem, 0, len(pbRawRootItem))
|
||||||
|
for _, item := range pbRawRootItem {
|
||||||
|
proof.RawRootItem = append(proof.RawRootItem, &MerkleTreeProofItem{
|
||||||
|
Floor: item.GetFloor(),
|
||||||
|
Hash: item.GetHash(),
|
||||||
|
Left: item.GetLeft(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return proof
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProofToProtobuf 将model.Proof转换为protobuf的Proof.
|
||||||
|
func ProofToProtobuf(proof *Proof) *pb.Proof {
|
||||||
|
if proof == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
Sign: proof.Sign,
|
||||||
|
Version: proof.Version,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 ColItems
|
||||||
|
if len(proof.ColItems) > 0 {
|
||||||
|
pbProof.ColItems = make([]*pb.MerkleTreeProofItem, 0, len(proof.ColItems))
|
||||||
|
for _, item := range proof.ColItems {
|
||||||
|
pbProof.ColItems = append(pbProof.ColItems, &pb.MerkleTreeProofItem{
|
||||||
|
Floor: item.Floor,
|
||||||
|
Hash: item.Hash,
|
||||||
|
Left: item.Left,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 RawItems
|
||||||
|
if len(proof.RawItems) > 0 {
|
||||||
|
pbProof.RawItems = make([]*pb.MerkleTreeProofItem, 0, len(proof.RawItems))
|
||||||
|
for _, item := range proof.RawItems {
|
||||||
|
pbProof.RawItems = append(pbProof.RawItems, &pb.MerkleTreeProofItem{
|
||||||
|
Floor: item.Floor,
|
||||||
|
Hash: item.Hash,
|
||||||
|
Left: item.Left,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 ColRootItem
|
||||||
|
if len(proof.ColRootItem) > 0 {
|
||||||
|
pbProof.ColRootItem = make([]*pb.MerkleTreeProofItem, 0, len(proof.ColRootItem))
|
||||||
|
for _, item := range proof.ColRootItem {
|
||||||
|
pbProof.ColRootItem = append(pbProof.ColRootItem, &pb.MerkleTreeProofItem{
|
||||||
|
Floor: item.Floor,
|
||||||
|
Hash: item.Hash,
|
||||||
|
Left: item.Left,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 RawRootItem
|
||||||
|
if len(proof.RawRootItem) > 0 {
|
||||||
|
pbProof.RawRootItem = make([]*pb.MerkleTreeProofItem, 0, len(proof.RawRootItem))
|
||||||
|
for _, item := range proof.RawRootItem {
|
||||||
|
pbProof.RawRootItem = append(pbProof.RawRootItem, &pb.MerkleTreeProofItem{
|
||||||
|
Floor: item.Floor,
|
||||||
|
Hash: item.Hash,
|
||||||
|
Left: item.Left,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pbProof
|
||||||
|
}
|
||||||
349
api/model/proof_test.go
Normal file
349
api/model/proof_test.go
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := model.ProofFromProtobuf(nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Empty(t, result.Sign)
|
||||||
|
assert.Empty(t, result.Version)
|
||||||
|
assert.Nil(t, result.ColItems)
|
||||||
|
assert.Nil(t, result.RawItems)
|
||||||
|
assert.Nil(t, result.ColRootItem)
|
||||||
|
assert.Nil(t, result.RawRootItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithSign(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
Sign: "test-signature",
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "test-signature", result.Sign)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithVersion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
Version: "v1.0.0",
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "v1.0.0", result.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithColItems(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
{Floor: 2, Hash: "hash2", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.ColItems, 2)
|
||||||
|
assert.Equal(t, uint32(1), result.ColItems[0].Floor)
|
||||||
|
assert.Equal(t, "hash1", result.ColItems[0].Hash)
|
||||||
|
assert.True(t, result.ColItems[0].Left)
|
||||||
|
assert.Equal(t, uint32(2), result.ColItems[1].Floor)
|
||||||
|
assert.Equal(t, "hash2", result.ColItems[1].Hash)
|
||||||
|
assert.False(t, result.ColItems[1].Left)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithRawItems(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
RawItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 3, Hash: "hash3", Left: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.RawItems, 1)
|
||||||
|
assert.Equal(t, uint32(3), result.RawItems[0].Floor)
|
||||||
|
assert.Equal(t, "hash3", result.RawItems[0].Hash)
|
||||||
|
assert.True(t, result.RawItems[0].Left)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithColRootItem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 4, Hash: "hash4", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.ColRootItem, 1)
|
||||||
|
assert.Equal(t, uint32(4), result.ColRootItem[0].Floor)
|
||||||
|
assert.Equal(t, "hash4", result.ColRootItem[0].Hash)
|
||||||
|
assert.False(t, result.ColRootItem[0].Left)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_WithRawRootItem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 5, Hash: "hash5", Left: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.RawRootItem, 1)
|
||||||
|
assert.Equal(t, uint32(5), result.RawRootItem[0].Floor)
|
||||||
|
assert.Equal(t, "hash5", result.RawRootItem[0].Hash)
|
||||||
|
assert.True(t, result.RawRootItem[0].Left)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofFromProtobuf_Full(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pbProof := &pb.Proof{
|
||||||
|
Sign: "full-signature",
|
||||||
|
Version: "v1.0.0",
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "col1", Left: true},
|
||||||
|
},
|
||||||
|
RawItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 2, Hash: "raw1", Left: false},
|
||||||
|
},
|
||||||
|
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 3, Hash: "colroot1", Left: true},
|
||||||
|
},
|
||||||
|
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 4, Hash: "rawroot1", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofFromProtobuf(pbProof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "full-signature", result.Sign)
|
||||||
|
assert.Equal(t, "v1.0.0", result.Version)
|
||||||
|
assert.Len(t, result.ColItems, 1)
|
||||||
|
assert.Len(t, result.RawItems, 1)
|
||||||
|
assert.Len(t, result.ColRootItem, 1)
|
||||||
|
assert.Len(t, result.RawRootItem, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := model.ProofToProtobuf(nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Empty(t, result.GetSign())
|
||||||
|
assert.Empty(t, result.GetVersion())
|
||||||
|
assert.Nil(t, result.GetColItems())
|
||||||
|
assert.Nil(t, result.GetRawItems())
|
||||||
|
assert.Nil(t, result.GetColRootItem())
|
||||||
|
assert.Nil(t, result.GetRawRootItem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithSign(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
Sign: "test-signature",
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "test-signature", result.GetSign())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithVersion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
Version: "v1.0.0",
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "v1.0.0", result.GetVersion())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithColItems(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
ColItems: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
{Floor: 2, Hash: "hash2", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.GetColItems(), 2)
|
||||||
|
assert.Equal(t, uint32(1), result.GetColItems()[0].GetFloor())
|
||||||
|
assert.Equal(t, "hash1", result.GetColItems()[0].GetHash())
|
||||||
|
assert.True(t, result.GetColItems()[0].GetLeft())
|
||||||
|
assert.Equal(t, uint32(2), result.GetColItems()[1].GetFloor())
|
||||||
|
assert.Equal(t, "hash2", result.GetColItems()[1].GetHash())
|
||||||
|
assert.False(t, result.GetColItems()[1].GetLeft())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithRawItems(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
RawItems: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 3, Hash: "hash3", Left: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.GetRawItems(), 1)
|
||||||
|
assert.Equal(t, uint32(3), result.GetRawItems()[0].GetFloor())
|
||||||
|
assert.Equal(t, "hash3", result.GetRawItems()[0].GetHash())
|
||||||
|
assert.True(t, result.GetRawItems()[0].GetLeft())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithColRootItem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
ColRootItem: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 4, Hash: "hash4", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.GetColRootItem(), 1)
|
||||||
|
assert.Equal(t, uint32(4), result.GetColRootItem()[0].GetFloor())
|
||||||
|
assert.Equal(t, "hash4", result.GetColRootItem()[0].GetHash())
|
||||||
|
assert.False(t, result.GetColRootItem()[0].GetLeft())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_WithRawRootItem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
RawRootItem: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 5, Hash: "hash5", Left: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Len(t, result.GetRawRootItem(), 1)
|
||||||
|
assert.Equal(t, uint32(5), result.GetRawRootItem()[0].GetFloor())
|
||||||
|
assert.Equal(t, "hash5", result.GetRawRootItem()[0].GetHash())
|
||||||
|
assert.True(t, result.GetRawRootItem()[0].GetLeft())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofToProtobuf_Full(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
proof := &model.Proof{
|
||||||
|
Sign: "full-signature",
|
||||||
|
Version: "v1.0.0",
|
||||||
|
ColItems: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "col1", Left: true},
|
||||||
|
},
|
||||||
|
RawItems: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 2, Hash: "raw1", Left: false},
|
||||||
|
},
|
||||||
|
ColRootItem: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 3, Hash: "colroot1", Left: true},
|
||||||
|
},
|
||||||
|
RawRootItem: []*model.MerkleTreeProofItem{
|
||||||
|
{Floor: 4, Hash: "rawroot1", Left: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := model.ProofToProtobuf(proof)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "full-signature", result.GetSign())
|
||||||
|
assert.Equal(t, "v1.0.0", result.GetVersion())
|
||||||
|
assert.Len(t, result.GetColItems(), 1)
|
||||||
|
assert.Len(t, result.GetRawItems(), 1)
|
||||||
|
assert.Len(t, result.GetColRootItem(), 1)
|
||||||
|
assert.Len(t, result.GetRawRootItem(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := &pb.Proof{
|
||||||
|
Sign: "round-trip-signature",
|
||||||
|
Version: "v1.0.0",
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "col1", Left: true},
|
||||||
|
{Floor: 2, Hash: "col2", Left: false},
|
||||||
|
},
|
||||||
|
RawItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 3, Hash: "raw1", Left: true},
|
||||||
|
},
|
||||||
|
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 4, Hash: "colroot1", Left: false},
|
||||||
|
},
|
||||||
|
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 5, Hash: "rawroot1", Left: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to model
|
||||||
|
modelProof := model.ProofFromProtobuf(original)
|
||||||
|
require.NotNil(t, modelProof)
|
||||||
|
|
||||||
|
// Convert back to protobuf
|
||||||
|
pbProof := model.ProofToProtobuf(modelProof)
|
||||||
|
require.NotNil(t, pbProof)
|
||||||
|
|
||||||
|
// Verify round trip
|
||||||
|
assert.Equal(t, original.GetSign(), pbProof.GetSign())
|
||||||
|
assert.Equal(t, original.GetVersion(), pbProof.GetVersion())
|
||||||
|
assert.Len(t, pbProof.GetColItems(), 2)
|
||||||
|
assert.Len(t, pbProof.GetRawItems(), 1)
|
||||||
|
assert.Len(t, pbProof.GetColRootItem(), 1)
|
||||||
|
assert.Len(t, pbProof.GetRawRootItem(), 1)
|
||||||
|
|
||||||
|
assert.Equal(t, original.GetColItems()[0].GetFloor(), pbProof.GetColItems()[0].GetFloor())
|
||||||
|
assert.Equal(t, original.GetColItems()[0].GetHash(), pbProof.GetColItems()[0].GetHash())
|
||||||
|
assert.Equal(t, original.GetColItems()[0].GetLeft(), pbProof.GetColItems()[0].GetLeft())
|
||||||
|
}
|
||||||
348
api/model/record.go
Normal file
348
api/model/record.go
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Record 表示一条记录。
|
||||||
|
// 用于记录系统中的操作行为,包含记录标识、节点前缀、操作者信息等。
|
||||||
|
type Record struct {
|
||||||
|
ID string `json:"id" validate:"required,max=128"`
|
||||||
|
DoPrefix string `json:"doPrefix" validate:"max=512"`
|
||||||
|
ProducerID string `json:"producerId" validate:"required,max=512"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Operator string `json:"operator" validate:"max=64"`
|
||||||
|
Extra []byte `json:"extra" validate:"max=512"`
|
||||||
|
RCType string `json:"type" validate:"max=64"`
|
||||||
|
binary []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 构造函数 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// NewFullRecord 创建包含所有字段的完整 Record。
|
||||||
|
// 自动完成字段校验,确保创建的 Record 是完整且有效的。
|
||||||
|
func NewFullRecord(
|
||||||
|
doPrefix string,
|
||||||
|
producerID string,
|
||||||
|
timestamp time.Time,
|
||||||
|
operator string,
|
||||||
|
extra []byte,
|
||||||
|
rcType string,
|
||||||
|
) (*Record, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new full record",
|
||||||
|
"doPrefix", doPrefix,
|
||||||
|
"producerID", producerID,
|
||||||
|
"operator", operator,
|
||||||
|
"rcType", rcType,
|
||||||
|
"extraLength", len(extra),
|
||||||
|
)
|
||||||
|
record := &Record{
|
||||||
|
DoPrefix: doPrefix,
|
||||||
|
ProducerID: producerID,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Operator: operator,
|
||||||
|
Extra: extra,
|
||||||
|
RCType: rcType,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Checking and initializing record")
|
||||||
|
if err := record.CheckAndInit(); err != nil {
|
||||||
|
log.Error("Failed to check and init record",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Full record created successfully",
|
||||||
|
"recordID", record.ID,
|
||||||
|
)
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 接口实现 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
func (r *Record) Key() string {
|
||||||
|
return r.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordHashData 实现 HashData 接口,用于存储 Record 的哈希计算结果。
|
||||||
|
type RecordHashData struct {
|
||||||
|
key string
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RecordHashData) Key() string {
|
||||||
|
return r.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RecordHashData) Hash() string {
|
||||||
|
return r.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RecordHashData) Type() HashType {
|
||||||
|
return Sha256Simd
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoHash 计算 Record 的整体哈希值,用于数据完整性验证。
|
||||||
|
// 哈希基于序列化后的二进制数据计算,确保记录数据的不可篡改性。
|
||||||
|
func (r *Record) DoHash(_ context.Context) (HashData, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Computing hash for record",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
hashTool := GetHashTool(Sha256Simd)
|
||||||
|
binary, err := r.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal record for hash",
|
||||||
|
"error", err,
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to marshal record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Computing hash bytes",
|
||||||
|
"recordID", r.ID,
|
||||||
|
"binaryLength", len(binary),
|
||||||
|
)
|
||||||
|
hash, err := hashTool.HashBytes(binary)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to compute hash",
|
||||||
|
"error", err,
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to compute hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Hash computed successfully",
|
||||||
|
"recordID", r.ID,
|
||||||
|
"hash", hash,
|
||||||
|
)
|
||||||
|
return RecordHashData{
|
||||||
|
key: r.ID,
|
||||||
|
hash: hash,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== CBOR 序列化相关 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// recordData 用于 CBOR 序列化/反序列化的中间结构。
|
||||||
|
// 排除缓存字段,仅包含可序列化的数据字段。
|
||||||
|
type recordData struct {
|
||||||
|
ID *string `cbor:"id"`
|
||||||
|
DoPrefix *string `cbor:"doPrefix"`
|
||||||
|
ProducerID *string `cbor:"producerId"`
|
||||||
|
Timestamp *time.Time `cbor:"timestamp"`
|
||||||
|
Operator *string `cbor:"operator"`
|
||||||
|
Extra []byte `cbor:"extra"`
|
||||||
|
RCType *string `cbor:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// toRecordData 将 Record 转换为 recordData,用于序列化。
|
||||||
|
func (r *Record) toRecordData() *recordData {
|
||||||
|
return &recordData{
|
||||||
|
ID: &r.ID,
|
||||||
|
DoPrefix: &r.DoPrefix,
|
||||||
|
ProducerID: &r.ProducerID,
|
||||||
|
Timestamp: &r.Timestamp,
|
||||||
|
Operator: &r.Operator,
|
||||||
|
Extra: r.Extra,
|
||||||
|
RCType: &r.RCType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fromRecordData 从 recordData 填充 Record,用于反序列化。
|
||||||
|
func (r *Record) fromRecordData(recData *recordData) {
|
||||||
|
if recData == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if recData.ID != nil {
|
||||||
|
r.ID = *recData.ID
|
||||||
|
}
|
||||||
|
if recData.DoPrefix != nil {
|
||||||
|
r.DoPrefix = *recData.DoPrefix
|
||||||
|
}
|
||||||
|
if recData.ProducerID != nil {
|
||||||
|
r.ProducerID = *recData.ProducerID
|
||||||
|
}
|
||||||
|
if recData.Timestamp != nil {
|
||||||
|
r.Timestamp = *recData.Timestamp
|
||||||
|
}
|
||||||
|
if recData.Operator != nil {
|
||||||
|
r.Operator = *recData.Operator
|
||||||
|
}
|
||||||
|
if recData.Extra != nil {
|
||||||
|
r.Extra = recData.Extra
|
||||||
|
}
|
||||||
|
if recData.RCType != nil {
|
||||||
|
r.RCType = *recData.RCType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalBinary 将 Record 序列化为 CBOR 格式的二进制数据。
|
||||||
|
// 实现 encoding.BinaryMarshaler 接口。
|
||||||
|
// 使用 Canonical CBOR 编码确保序列化结果的一致性,使用缓存机制避免重复序列化。
|
||||||
|
func (r *Record) MarshalBinary() ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling record to CBOR binary",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
if r.binary != nil {
|
||||||
|
log.Debug("Using cached binary data",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return r.binary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
recData := r.toRecordData()
|
||||||
|
|
||||||
|
log.Debug("Marshaling record data to canonical CBOR",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
binary, err := helpers.MarshalCanonical(recData)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to marshal record to CBOR",
|
||||||
|
"error", err,
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to marshal record to CBOR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.binary = binary
|
||||||
|
|
||||||
|
log.Debug("Record marshaled successfully",
|
||||||
|
"recordID", r.ID,
|
||||||
|
"binaryLength", len(binary),
|
||||||
|
)
|
||||||
|
return binary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalBinary 从 CBOR 格式的二进制数据反序列化为 Record。
|
||||||
|
// 实现 encoding.BinaryUnmarshaler 接口。
|
||||||
|
func (r *Record) UnmarshalBinary(data []byte) error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Unmarshaling record from CBOR binary",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
if len(data) == 0 {
|
||||||
|
log.Error("Data is empty")
|
||||||
|
return errors.New("data is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
recData := &recordData{}
|
||||||
|
|
||||||
|
log.Debug("Unmarshaling record data from CBOR")
|
||||||
|
if err := helpers.Unmarshal(data, recData); err != nil {
|
||||||
|
log.Error("Failed to unmarshal record from CBOR",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return fmt.Errorf("failed to unmarshal record from CBOR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.fromRecordData(recData)
|
||||||
|
|
||||||
|
r.binary = data
|
||||||
|
|
||||||
|
log.Debug("Record unmarshaled successfully",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDoPrefix 实现 DoPrefixExtractor 接口,返回节点前缀。
|
||||||
|
func (r *Record) GetDoPrefix() string {
|
||||||
|
return r.DoPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProducerID 返回 ProducerID,实现 Trustlog 接口。
|
||||||
|
func (r *Record) GetProducerID() string {
|
||||||
|
return r.ProducerID
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 初始化与验证 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// CheckAndInit 校验并初始化 Record。
|
||||||
|
// 自动填充缺失字段(ID),字段非空验证由 validate 标签处理。
|
||||||
|
func (r *Record) CheckAndInit() error {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Checking and initializing record",
|
||||||
|
"producerID", r.ProducerID,
|
||||||
|
"doPrefix", r.DoPrefix,
|
||||||
|
)
|
||||||
|
if r.ID == "" {
|
||||||
|
r.ID = helpers.NewUUIDv7()
|
||||||
|
log.Debug("Generated new record ID",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Timestamp.IsZero() {
|
||||||
|
r.Timestamp = time.Now()
|
||||||
|
log.Debug("Set default timestamp",
|
||||||
|
"timestamp", r.Timestamp,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Validating record struct")
|
||||||
|
if err := helpers.GetValidator().Struct(r); err != nil {
|
||||||
|
log.Error("Record validation failed",
|
||||||
|
"error", err,
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Record checked and initialized successfully",
|
||||||
|
"recordID", r.ID,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ===== 链式调用支持 =====
|
||||||
|
//
|
||||||
|
|
||||||
|
// WithDoPrefix 设置 DoPrefix 并返回自身,支持链式调用。
|
||||||
|
func (r *Record) WithDoPrefix(doPrefix string) *Record {
|
||||||
|
r.DoPrefix = doPrefix
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimestamp 设置 Timestamp 并返回自身,支持链式调用。
|
||||||
|
func (r *Record) WithTimestamp(timestamp time.Time) *Record {
|
||||||
|
r.Timestamp = timestamp
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOperator 设置 Operator 并返回自身,支持链式调用。
|
||||||
|
func (r *Record) WithOperator(operator string) *Record {
|
||||||
|
r.Operator = operator
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExtra 设置 Extra 并返回自身,支持链式调用。
|
||||||
|
func (r *Record) WithExtra(extra []byte) *Record {
|
||||||
|
r.Extra = extra
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRCType 设置 RCType 并返回自身,支持链式调用。
|
||||||
|
func (r *Record) WithRCType(rcType string) *Record {
|
||||||
|
r.RCType = rcType
|
||||||
|
return r
|
||||||
|
}
|
||||||
321
api/model/record_test.go
Normal file
321
api/model/record_test.go
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecord_Key(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "test-record-id",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "test-record-id", rec.Key())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFullRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
rec, err := model.NewFullRecord(
|
||||||
|
"test-prefix",
|
||||||
|
"producer-1",
|
||||||
|
now,
|
||||||
|
"operator-1",
|
||||||
|
[]byte("extra"),
|
||||||
|
"log",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, rec)
|
||||||
|
assert.NotEmpty(t, rec.ID)
|
||||||
|
assert.Equal(t, "test-prefix", rec.DoPrefix)
|
||||||
|
assert.Equal(t, "producer-1", rec.ProducerID)
|
||||||
|
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||||
|
assert.Equal(t, "operator-1", rec.Operator)
|
||||||
|
assert.Equal(t, []byte("extra"), rec.Extra)
|
||||||
|
assert.Equal(t, "log", rec.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFullRecord_Invalid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
// Missing required ProducerID
|
||||||
|
rec, err := model.NewFullRecord(
|
||||||
|
"test-prefix",
|
||||||
|
"", // Empty ProducerID
|
||||||
|
now,
|
||||||
|
"operator-1",
|
||||||
|
[]byte("extra"),
|
||||||
|
"log",
|
||||||
|
)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, rec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_CheckAndInit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rec *model.Record
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid record",
|
||||||
|
rec: &model.Record{
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto generate ID",
|
||||||
|
rec: &model.Record{
|
||||||
|
ID: "", // Will be auto-generated
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing ProducerID",
|
||||||
|
rec: &model.Record{
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "", // Required field
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := tt.rec.CheckAndInit()
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
if tt.name == "auto generate ID" {
|
||||||
|
assert.NotEmpty(t, tt.rec.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_MarshalUnmarshalBinary(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := original.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := original.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
result := &model.Record{}
|
||||||
|
err = result.UnmarshalBinary(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Equal(t, original.ID, result.ID)
|
||||||
|
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||||
|
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||||
|
assert.Equal(t, original.Timestamp.Unix(), result.Timestamp.Unix())
|
||||||
|
// 验证纳秒精度被保留
|
||||||
|
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||||
|
"时间戳的纳秒精度应该被保留")
|
||||||
|
assert.Equal(t, original.Operator, result.Operator)
|
||||||
|
assert.Equal(t, original.Extra, result.Extra)
|
||||||
|
assert.Equal(t, original.RCType, result.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_MarshalBinary_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
// MarshalBinary should succeed even without CheckAndInit
|
||||||
|
// It just serializes the data
|
||||||
|
data, err := rec.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_UnmarshalBinary_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
err := rec.UnmarshalBinary([]byte{})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_DoHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rec.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
hashData, err := rec.DoHash(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, hashData)
|
||||||
|
assert.Equal(t, rec.ID, hashData.Key())
|
||||||
|
assert.NotEmpty(t, hashData.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordHashData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// RecordHashData is created through DoHash, test it indirectly
|
||||||
|
rec := &model.Record{
|
||||||
|
ID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := rec.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
hashData, err := rec.DoHash(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, hashData)
|
||||||
|
assert.Equal(t, "rec-123", hashData.Key())
|
||||||
|
assert.NotEmpty(t, hashData.Hash())
|
||||||
|
assert.Equal(t, model.Sha256Simd, hashData.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_GetProducerID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
ProducerID: "producer-123",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "producer-123", rec.GetProducerID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_GetDoPrefix(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{
|
||||||
|
DoPrefix: "test-prefix",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "test-prefix", rec.GetDoPrefix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_WithDoPrefix(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
result := rec.WithDoPrefix("test-prefix")
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, "test-prefix", rec.DoPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_WithTimestamp(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
now := time.Now()
|
||||||
|
result := rec.WithTimestamp(now)
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_WithOperator(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
result := rec.WithOperator("operator-1")
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, "operator-1", rec.Operator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_WithExtra(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
extra := []byte("extra-data")
|
||||||
|
result := rec.WithExtra(extra)
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, extra, rec.Extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_WithRCType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
result := rec.WithRCType("log")
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, "log", rec.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecord_ChainedMethods(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rec := &model.Record{}
|
||||||
|
now := time.Now()
|
||||||
|
result := rec.
|
||||||
|
WithDoPrefix("prefix").
|
||||||
|
WithTimestamp(now).
|
||||||
|
WithOperator("operator").
|
||||||
|
WithExtra([]byte("extra")).
|
||||||
|
WithRCType("log")
|
||||||
|
|
||||||
|
assert.Equal(t, rec, result)
|
||||||
|
assert.Equal(t, "prefix", rec.DoPrefix)
|
||||||
|
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||||
|
assert.Equal(t, "operator", rec.Operator)
|
||||||
|
assert.Equal(t, []byte("extra"), rec.Extra)
|
||||||
|
assert.Equal(t, "log", rec.RCType)
|
||||||
|
}
|
||||||
54
api/model/record_timestamp_test.go
Normal file
54
api/model/record_timestamp_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRecord_TimestampNanosecondPrecision 验证 Record 的时间戳在 CBOR 序列化/反序列化后能保留纳秒精度
|
||||||
|
func TestRecord_TimestampNanosecondPrecision(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 创建一个包含纳秒精度的时间戳
|
||||||
|
timestamp := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||||
|
|
||||||
|
original := &model.Record{
|
||||||
|
ID: "rec-nanosecond-test",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerID: "producer-1",
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Operator: "operator-1",
|
||||||
|
Extra: []byte("extra"),
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := original.CheckAndInit()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||||
|
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 序列化
|
||||||
|
data, err := original.MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// 反序列化
|
||||||
|
result := &model.Record{}
|
||||||
|
err = result.UnmarshalBinary(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Decoded timestamp: %v", result.Timestamp)
|
||||||
|
t.Logf("Decoded nanoseconds: %d", result.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 验证纳秒精度被完整保留
|
||||||
|
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||||
|
"时间戳的纳秒精度应该被完整保留")
|
||||||
|
assert.Equal(t, original.Timestamp.Nanosecond(), result.Timestamp.Nanosecond(),
|
||||||
|
"纳秒部分应该相等")
|
||||||
|
}
|
||||||
393
api/model/signature.go
Normal file
393
api/model/signature.go
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/crpt/go-crpt"
|
||||||
|
_ "github.com/crpt/go-crpt/sm2" // Import SM2 to register it
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPrivateKeyIsNil = errors.New("private key is nil")
|
||||||
|
ErrPublicAndKeysNotMatch = errors.New("public and private keys don't match")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ComputeSignature 计算SM2签名.
|
||||||
|
// 这是 SDK 默认的签名函数,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - data: 待签名的原始数据
|
||||||
|
// - privateKeyDER: 私钥的DER编码字节数组
|
||||||
|
//
|
||||||
|
// 返回: 签名字节数组.
|
||||||
|
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并签名。
|
||||||
|
func ComputeSignature(data, privateKeyDER []byte) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Computing SM2 signature",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"privateKeyDERLength", len(privateKeyDER),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(privateKeyDER) == 0 {
|
||||||
|
log.Error("Private key is empty")
|
||||||
|
return nil, errors.New("private key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
log.Error("Data to sign is empty")
|
||||||
|
return nil, errors.New("data to sign cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析DER格式的私钥
|
||||||
|
log.Debug("Parsing SM2 private key from DER format")
|
||||||
|
privateKey, err := crpt.PrivateKeyFromBytes(crpt.SM2, privateKeyDER)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse SM2 private key",
|
||||||
|
"error", err,
|
||||||
|
"keyLength", len(privateKeyDER),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to parse SM2 private key (key length: %d): %w", len(privateKeyDER), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if privateKey == nil {
|
||||||
|
log.Error("Parsed private key is nil")
|
||||||
|
return nil, ErrPrivateKeyIsNil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用SM2签名(ASN.1编码),go-crpt 库会自动使用 SM3 计算摘要
|
||||||
|
log.Debug("Signing raw data with SM2 using ASN.1 encoding (SM3 hash)")
|
||||||
|
signature, err := crpt.SignMessage(privateKey, data, rand.Reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign data with SM2",
|
||||||
|
"error", err,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to sign data with SM2 (data length: %d): %w", len(data), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("SM2 signature computed successfully",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySignature 验证SM2签名.
|
||||||
|
// 这是 SDK 默认的验签函数,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - data: 原始数据
|
||||||
|
// - publicKeyDER: 公钥的DER编码字节数组
|
||||||
|
// - signature: 签名字节数组
|
||||||
|
//
|
||||||
|
// 返回: 验证是否成功和可能的错误.
|
||||||
|
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并验证。
|
||||||
|
func VerifySignature(data, publicKeyDER, signature []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying SM2 signature",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"publicKeyDERLength", len(publicKeyDER),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(publicKeyDER) == 0 {
|
||||||
|
log.Error("Public key is empty")
|
||||||
|
return false, errors.New("public key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
log.Error("Data to verify is empty")
|
||||||
|
return false, errors.New("data to verify cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(signature) == 0 {
|
||||||
|
log.Error("Signature is empty")
|
||||||
|
return false, errors.New("signature cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析DER格式的公钥,复用ParseSM2PublicDER以避免代码重复
|
||||||
|
log.Debug("Parsing SM2 public key from DER format")
|
||||||
|
publicKey, err := ParseSM2PublicDER(publicKeyDER)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse SM2 public key",
|
||||||
|
"error", err,
|
||||||
|
"keyLength", len(publicKeyDER),
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf("failed to parse SM2 public key (key length: %d): %w", len(publicKeyDER), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证签名(ASN.1编码),go-crpt 库会自动使用 SM3 计算摘要
|
||||||
|
log.Debug("Verifying signature with SM2 using ASN.1 encoding (SM3 hash)")
|
||||||
|
ok, err := crpt.VerifyMessage(publicKey, data, crpt.Signature(signature), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to verify SM2 signature",
|
||||||
|
"error", err,
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf("failed to verify signature: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
log.Warn("SM2 signature verification failed",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return false, fmt.Errorf(
|
||||||
|
"signature verification failed (data length: %d, signature length: %d)",
|
||||||
|
len(data), len(signature),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
log.Debug("SM2 signature verified successfully",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSM2KeyPair 生成SM2密钥对.
|
||||||
|
// 这是 SDK 默认推荐的密钥生成方法。
|
||||||
|
//
|
||||||
|
// 返回新生成的密钥对,包含公钥和私钥.
|
||||||
|
// SM2 算法会在签名时自动使用 SM3 哈希。
|
||||||
|
func GenerateSM2KeyPair() (*SM2KeyPair, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Generating SM2 key pair")
|
||||||
|
pub, priv, err := crpt.GenerateKey(crpt.SM2, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to generate SM2 key pair", "error", err)
|
||||||
|
return nil, fmt.Errorf("failed to generate SM2 key pair: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if priv == nil {
|
||||||
|
log.Error("Generated private key is nil")
|
||||||
|
return nil, errors.New("generated private key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("SM2 key pair generated successfully")
|
||||||
|
return &SM2KeyPair{
|
||||||
|
Public: pub,
|
||||||
|
Private: priv,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SM2KeyPair SM2密钥对,包含公钥和私钥.
|
||||||
|
type SM2KeyPair struct {
|
||||||
|
Public crpt.PublicKey `json:"publicKey"`
|
||||||
|
Private crpt.PrivateKey `json:"privateKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalSM2PrivateDER 将私钥编码为DER格式.
|
||||||
|
// 将SM2私钥转换为DER格式的字节数组,用于存储或传输.
|
||||||
|
func MarshalSM2PrivateDER(priv crpt.PrivateKey) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling SM2 private key to DER format")
|
||||||
|
if priv == nil {
|
||||||
|
log.Error("Private key is nil")
|
||||||
|
return nil, errors.New("private key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
der := priv.Bytes()
|
||||||
|
log.Debug("SM2 private key marshaled to DER successfully",
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
return der, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseSM2PrivateDER 从DER格式解析私钥.
|
||||||
|
// 将DER格式的字节数组解析为SM2私钥对象.
|
||||||
|
func ParseSM2PrivateDER(der []byte) (crpt.PrivateKey, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Parsing SM2 private key from DER format",
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
if len(der) == 0 {
|
||||||
|
log.Error("DER encoded private key is empty")
|
||||||
|
return nil, errors.New("DER encoded private key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := crpt.PrivateKeyFromBytes(crpt.SM2, der)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse SM2 private key from DER",
|
||||||
|
"error", err,
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to parse SM2 private key from DER (length: %d): %w", len(der), err)
|
||||||
|
}
|
||||||
|
log.Debug("SM2 private key parsed from DER successfully")
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalSM2PublicDER 将公钥编码为DER格式.
|
||||||
|
// 将SM2公钥转换为DER格式的字节数组,用于存储或传输.
|
||||||
|
func MarshalSM2PublicDER(pub crpt.PublicKey) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Marshaling SM2 public key to DER format")
|
||||||
|
if pub == nil {
|
||||||
|
log.Error("Public key is nil")
|
||||||
|
return nil, errors.New("public key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
der := pub.Bytes()
|
||||||
|
log.Debug("SM2 public key marshaled to DER successfully",
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
return der, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseSM2PublicDER 从DER格式解析公钥.
|
||||||
|
// 将DER格式的字节数组解析为SM2公钥对象.
|
||||||
|
// 返回解析后的公钥,如果数据不是有效的SM2公钥则返回错误.
|
||||||
|
func ParseSM2PublicDER(der []byte) (crpt.PublicKey, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Parsing SM2 public key from DER format",
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
if len(der) == 0 {
|
||||||
|
log.Error("DER encoded public key is empty")
|
||||||
|
return nil, errors.New("DER encoded public key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := crpt.PublicKeyFromBytes(crpt.SM2, der)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to parse SM2 public key",
|
||||||
|
"error", err,
|
||||||
|
"derLength", len(der),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("failed to parse SM2 public key (DER length: %d): %w", len(der), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("SM2 public key parsed from DER successfully")
|
||||||
|
return publicKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignMessage 使用密钥对签名消息(标准SM2签名).
|
||||||
|
// 使用标准SM2算法对消息进行签名,不包含用户标识(uid).
|
||||||
|
func (kp *SM2KeyPair) SignMessage(msg []byte) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing message with SM2 key pair",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
if kp.Private == nil {
|
||||||
|
log.Error("Private key is nil")
|
||||||
|
return nil, ErrPrivateKeyIsNil
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := crpt.SignMessage(kp.Private, msg, rand.Reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign message with SM2",
|
||||||
|
"error", err,
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Message signed successfully with SM2",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignGM 使用密钥对签名消息(国密标准SM2签名,带uid).
|
||||||
|
// 使用符合GB/T 32918标准的SM2算法对消息进行签名,包含用户标识(uid).
|
||||||
|
// uid用于Z值计算,通常为用户ID或标识符.
|
||||||
|
func (kp *SM2KeyPair) SignGM(msg, uid []byte) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing message with SM2 GM standard",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"uidLength", len(uid),
|
||||||
|
)
|
||||||
|
if kp.Private == nil {
|
||||||
|
log.Error("Private key is nil")
|
||||||
|
return nil, ErrPrivateKeyIsNil
|
||||||
|
}
|
||||||
|
|
||||||
|
// go-crpt uses SM3 hash internally, pass nil for standard signing
|
||||||
|
signature, err := crpt.SignMessage(kp.Private, msg, rand.Reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign message with SM2 GM standard",
|
||||||
|
"error", err,
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Message signed successfully with SM2 GM standard",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyMessage 使用公钥验证签名(标准SM2验签).
|
||||||
|
// 验证标准SM2签名,不使用用户标识(uid).
|
||||||
|
// 返回验证结果和可能的错误.如果验证失败但没有错误发生,返回(false, nil).
|
||||||
|
func (kp *SM2KeyPair) VerifyMessage(msg, sig []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying message signature with SM2",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(sig),
|
||||||
|
)
|
||||||
|
if kp.Public == nil {
|
||||||
|
log.Error("Public key is nil")
|
||||||
|
return false, ErrPublicAndKeysNotMatch
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := crpt.VerifyMessage(kp.Public, msg, crpt.Signature(sig), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error verifying message with SM2",
|
||||||
|
"error", err,
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Debug("Message signature verified successfully with SM2",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("Message signature verification failed with SM2",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(sig),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyGM 使用公钥验证签名(国密标准SM2验签,带uid).
|
||||||
|
// 验证符合GB/T 32918标准的SM2签名,使用用户标识(uid).
|
||||||
|
// 返回验证结果和可能的错误.如果验证失败但没有错误发生,返回(false, nil).
|
||||||
|
func (kp *SM2KeyPair) VerifyGM(msg, sig, uid []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying message signature with SM2 GM standard",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(sig),
|
||||||
|
"uidLength", len(uid),
|
||||||
|
)
|
||||||
|
if kp.Public == nil {
|
||||||
|
log.Error("Public key is nil")
|
||||||
|
return false, ErrPublicAndKeysNotMatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// go-crpt uses SM3 hash internally
|
||||||
|
ok, err := crpt.VerifyMessage(kp.Public, msg, crpt.Signature(sig), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error verifying message with SM2 GM standard",
|
||||||
|
"error", err,
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Debug("Message signature verified successfully with SM2 GM standard",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("Message signature verification failed with SM2 GM standard",
|
||||||
|
"messageLength", len(msg),
|
||||||
|
"signatureLength", len(sig),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
253
api/model/signature_test.go
Normal file
253
api/model/signature_test.go
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestComputeSignature_EmptyPrivateKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ComputeSignature([]byte("data"), nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "private key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeSignature_EmptyData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
privateKey := []byte("invalid-key")
|
||||||
|
_, err := model.ComputeSignature(nil, privateKey)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data to sign cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeSignature_InvalidKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ComputeSignature([]byte("data"), []byte("invalid-key"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse SM2 private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySignature_EmptyPublicKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.VerifySignature([]byte("data"), nil, []byte("signature"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "public key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySignature_EmptyData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
publicKey := []byte("invalid-key")
|
||||||
|
_, err := model.VerifySignature(nil, publicKey, []byte("signature"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "data to verify cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifySignature_InvalidKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
publicKey := []byte("invalid-key")
|
||||||
|
valid, err := model.VerifySignature([]byte("data"), publicKey, []byte("signature"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.False(t, valid)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse SM2 public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSM2KeyPair(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, keyPair)
|
||||||
|
assert.NotNil(t, keyPair.Public)
|
||||||
|
assert.NotNil(t, keyPair.Private)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSM2PrivateDER_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.MarshalSM2PrivateDER(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "private key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSM2PrivateDER(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
der, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, der)
|
||||||
|
assert.NotEmpty(t, der)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PrivateDER_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ParseSM2PrivateDER(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "DER encoded private key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PrivateDER_Invalid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ParseSM2PrivateDER([]byte("invalid-der"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse SM2 private key from DER")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PrivateDER_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
der, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKey, err := model.ParseSM2PrivateDER(der)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSM2PublicDER_Nil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.MarshalSM2PublicDER(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "public key is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSM2PublicDER(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
der, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, der)
|
||||||
|
assert.NotEmpty(t, der)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PublicDER_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ParseSM2PublicDER(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "DER encoded public key cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PublicDER_Invalid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := model.ParseSM2PublicDER([]byte("invalid-der"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse SM2 public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSM2PublicDER_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
der, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKey, err := model.ParseSM2PublicDER(der)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSM2SignAndVerify_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate key pair
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal keys
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign data
|
||||||
|
data := []byte("test data")
|
||||||
|
signature, err := model.ComputeSignature(data, privateKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, signature)
|
||||||
|
assert.NotEmpty(t, signature)
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
valid, err := model.VerifySignature(data, publicKeyDER, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSM2SignAndVerify_WrongData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate key pair
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal keys
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign data
|
||||||
|
data := []byte("test data")
|
||||||
|
signature, err := model.ComputeSignature(data, privateKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify with wrong data
|
||||||
|
wrongData := []byte("wrong data")
|
||||||
|
valid, err := model.VerifySignature(wrongData, publicKeyDER, signature)
|
||||||
|
// Verification should return error
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.False(t, valid)
|
||||||
|
assert.Contains(t, err.Error(), "signature verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSM2SignAndVerify_WrongSignature(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate key pair
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Marshal keys
|
||||||
|
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign data
|
||||||
|
data := []byte("test data")
|
||||||
|
_, err = model.ComputeSignature(data, privateKeyDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify with wrong signature
|
||||||
|
wrongSignature := []byte("wrong signature")
|
||||||
|
valid, err := model.VerifySignature(data, publicKeyDER, wrongSignature)
|
||||||
|
require.Error(t, err) // Should fail verification
|
||||||
|
assert.False(t, valid)
|
||||||
|
}
|
||||||
155
api/model/signer.go
Normal file
155
api/model/signer.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Signer 签名器接口,用于抽象不同的签名算法实现。
|
||||||
|
// 实现了此接口的类型可以提供签名和验签功能。
|
||||||
|
//
|
||||||
|
// SDK 默认使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||||
|
// 可通过 SetGlobalCryptoConfig 切换到其他算法(如 Ed25519)。
|
||||||
|
type Signer interface {
|
||||||
|
// Sign 对数据进行签名。
|
||||||
|
// 参数:
|
||||||
|
// - data: 待签名的原始数据
|
||||||
|
// 返回: 签名字节数组和可能的错误
|
||||||
|
Sign(data []byte) ([]byte, error)
|
||||||
|
|
||||||
|
// Verify 验证签名。
|
||||||
|
// 参数:
|
||||||
|
// - data: 原始数据
|
||||||
|
// - signature: 签名字节数组
|
||||||
|
//
|
||||||
|
// 返回: 验证是否成功和可能的错误
|
||||||
|
Verify(data, signature []byte) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SM2Signer SM2签名器实现。
|
||||||
|
// 使用SM2算法进行签名和验签(内部自动使用 SM3 哈希)。
|
||||||
|
//
|
||||||
|
// 这是 SDK 的默认签名算法。如需使用其他算法,请使用 ConfigSigner。
|
||||||
|
type SM2Signer struct {
|
||||||
|
privateKey []byte // 私钥(DER编码格式)
|
||||||
|
publicKey []byte // 公钥(DER编码格式)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSM2Signer 创建新的SM2签名器。
|
||||||
|
// 这是 SDK 默认推荐的签名器,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - privateKey: 私钥(DER编码格式),用于签名
|
||||||
|
// - publicKey: 公钥(DER编码格式),用于验签
|
||||||
|
//
|
||||||
|
// 示例:
|
||||||
|
//
|
||||||
|
// keyPair, _ := model.GenerateSM2KeyPair()
|
||||||
|
// privateKeyDER, _ := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||||
|
// publicKeyDER, _ := model.MarshalSM2PublicDER(keyPair.Public)
|
||||||
|
// signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||||
|
func NewSM2Signer(privateKey, publicKey []byte) *SM2Signer {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new SM2 signer (default algorithm, uses SM3 hash)",
|
||||||
|
"privateKeyLength", len(privateKey),
|
||||||
|
"publicKeyLength", len(publicKey),
|
||||||
|
)
|
||||||
|
return &SM2Signer{
|
||||||
|
privateKey: privateKey,
|
||||||
|
publicKey: publicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign 使用SM2私钥对数据进行签名(内部使用 SM3 哈希)。
|
||||||
|
func (s *SM2Signer) Sign(data []byte) ([]byte, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Signing data with SM2 (using SM3 hash)",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"privateKeyLength", len(s.privateKey),
|
||||||
|
)
|
||||||
|
signature, err := ComputeSignature(data, s.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to sign data with SM2",
|
||||||
|
"error", err,
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Debug("Data signed successfully with SM2",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 使用SM2公钥验证签名(内部使用 SM3 哈希)。
|
||||||
|
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并验证。
|
||||||
|
// 返回: 验证是否成功和可能的错误.
|
||||||
|
func (s *SM2Signer) Verify(data, signature []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Verifying signature with SM2",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
"publicKeyLength", len(s.publicKey),
|
||||||
|
)
|
||||||
|
valid, err := VerifySignature(data, s.publicKey, signature)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to verify signature with SM2",
|
||||||
|
"error", err,
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
log.Debug("Signature verified successfully with SM2",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("Signature verification failed with SM2",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return valid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NopSigner 空操作签名器实现。
|
||||||
|
// 对原hash不做任何操作,直接返回原数据。
|
||||||
|
// 适用于不需要实际签名操作的场景,如测试或某些特殊用途。
|
||||||
|
type NopSigner struct{}
|
||||||
|
|
||||||
|
// NewNopSigner 创建新的空操作签名器。
|
||||||
|
func NewNopSigner() *NopSigner {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("Creating new NopSigner")
|
||||||
|
return &NopSigner{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign 直接返回原数据,不做任何签名操作。
|
||||||
|
func (n *NopSigner) Sign(_ []byte) ([]byte, error) {
|
||||||
|
|
||||||
|
return ([]byte)("test"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 验证签名是否等于原数据。
|
||||||
|
func (n *NopSigner) Verify(data, signature []byte) (bool, error) {
|
||||||
|
log := logger.GetGlobalLogger()
|
||||||
|
log.Debug("NopSigner: verifying signature",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
valid := bytes.Equal(data, signature)
|
||||||
|
if valid {
|
||||||
|
log.Debug("NopSigner: signature verified successfully",
|
||||||
|
"dataLength", len(data),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warn("NopSigner: signature verification failed",
|
||||||
|
"dataLength", len(data),
|
||||||
|
"signatureLength", len(signature),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return valid, nil
|
||||||
|
}
|
||||||
135
api/model/signer_test.go
Normal file
135
api/model/signer_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSM2Signer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
privateKey := []byte("test-private-key")
|
||||||
|
publicKey := []byte("test-public-key")
|
||||||
|
|
||||||
|
signer := model.NewSM2Signer(privateKey, publicKey)
|
||||||
|
assert.NotNil(t, signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewNopSigner(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
assert.NotNil(t, signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_Sign(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
data := []byte("test data")
|
||||||
|
|
||||||
|
result, err := signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, data, result)
|
||||||
|
assert.NotSame(t, &data[0], &result[0]) // Should be a copy
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_Sign_Empty(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
data := []byte{}
|
||||||
|
|
||||||
|
result, err := signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, data, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_Verify_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
data := []byte("test data")
|
||||||
|
signature := []byte("test data") // Same as data
|
||||||
|
|
||||||
|
valid, err := signer.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_Verify_Failure(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
data := []byte("test data")
|
||||||
|
signature := []byte("different data")
|
||||||
|
|
||||||
|
valid, err := signer.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
data := []byte("test data")
|
||||||
|
|
||||||
|
signature, err := signer.Sign(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
valid, err := signer.Verify(data, signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopSigner_Verify_DifferentLengths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
signer := model.NewNopSigner()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
signature []byte
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "same data",
|
||||||
|
data: []byte("test"),
|
||||||
|
signature: []byte("test"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different data",
|
||||||
|
data: []byte("test"),
|
||||||
|
signature: []byte("test2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different lengths",
|
||||||
|
data: []byte("test"),
|
||||||
|
signature: []byte("test1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
data: []byte{},
|
||||||
|
signature: []byte{},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
valid, err := signer.Verify(tt.data, tt.signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, valid)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
65
api/model/sm2_consistency_test.go
Normal file
65
api/model/sm2_consistency_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSM2HashConsistency 验证SM2加签和验签的一致性
|
||||||
|
// 关键发现:SM2库内部会处理hash,但加签和验签必须使用相同的数据类型.
|
||||||
|
func TestSM2HashConsistency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 测试数据
|
||||||
|
originalData := []byte("test data for consistency check")
|
||||||
|
|
||||||
|
t.Logf("=== 测试1:加签和验签都使用原始数据(当前实现)===")
|
||||||
|
// 1. 加签:使用原始数据
|
||||||
|
signature1, err := keyPair.SignMessage(originalData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2. 验签:使用原始数据
|
||||||
|
valid1, err := keyPair.VerifyMessage(originalData, signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("加签(原始数据) + 验签(原始数据): %v", valid1)
|
||||||
|
assert.True(t, valid1, "加签和验签都使用原始数据应该成功")
|
||||||
|
|
||||||
|
t.Logf("\n=== 测试2:加签和验签都使用hash值 ===")
|
||||||
|
// 3. 加签:使用hash值
|
||||||
|
hashBytes := sha256.Sum256(originalData)
|
||||||
|
signature2, err := keyPair.SignMessage(hashBytes[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 4. 验签:使用hash值
|
||||||
|
valid2, err := keyPair.VerifyMessage(hashBytes[:], signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("加签(hash值) + 验签(hash值): %v", valid2)
|
||||||
|
assert.True(t, valid2, "加签和验签都使用hash值应该成功")
|
||||||
|
|
||||||
|
t.Logf("\n=== 测试3:不一致的情况(应该失败)===")
|
||||||
|
// 5. 加签使用原始数据,验签使用hash值 - 应该失败
|
||||||
|
valid3, err := keyPair.VerifyMessage(hashBytes[:], signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("加签(原始数据) + 验签(hash值): %v", valid3)
|
||||||
|
assert.False(t, valid3, "加签和验签使用不同类型数据应该失败")
|
||||||
|
|
||||||
|
// 6. 加签使用hash值,验签使用原始数据 - 应该失败
|
||||||
|
valid4, err := keyPair.VerifyMessage(originalData, signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("加签(hash值) + 验签(原始数据): %v", valid4)
|
||||||
|
assert.False(t, valid4, "加签和验签使用不同类型数据应该失败")
|
||||||
|
|
||||||
|
t.Logf("\n=== 结论 ===")
|
||||||
|
t.Logf("✓ SM2库内部会处理hash")
|
||||||
|
t.Logf("✓ 加签和验签必须使用相同的数据类型(都是原始数据,或都是hash值)")
|
||||||
|
t.Logf("✓ 当前实现(加签和验签都使用原始数据)是正确的")
|
||||||
|
}
|
||||||
82
api/model/sm2_hash_test.go
Normal file
82
api/model/sm2_hash_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSM2RequiresHash 测试SM2是否要求预先hash数据
|
||||||
|
// 根据文档,SM2.SignASN1期望接收hash值,而不是原始数据
|
||||||
|
// 但文档也提到:如果opts是*SM2SignerOption且ForceGMSign为true,则hash会被视为原始消息.
|
||||||
|
func TestSM2RequiresHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 生成SM2密钥对
|
||||||
|
keyPair, err := model.GenerateSM2KeyPair()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 测试数据
|
||||||
|
originalData := []byte("test data for SM2 signing")
|
||||||
|
|
||||||
|
// 1. 直接对原始数据签名(当前实现的方式)
|
||||||
|
// go-crpt 库会自动使用 SM3 计算摘要
|
||||||
|
signature1, err := keyPair.SignMessage(originalData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signature1)
|
||||||
|
|
||||||
|
// 2. 验证签名(使用原始数据)
|
||||||
|
valid1, err := keyPair.VerifyMessage(originalData, signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("直接使用原始数据签名和验证结果: %v", valid1)
|
||||||
|
assert.True(t, valid1, "当前实现:直接对原始数据签名和验证应该成功")
|
||||||
|
|
||||||
|
// 3. 先hash再签名(文档推荐的方式)
|
||||||
|
hashBytesReal := sha256.Sum256(originalData)
|
||||||
|
|
||||||
|
signature2, err := keyPair.SignMessage(hashBytesReal[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signature2)
|
||||||
|
|
||||||
|
// 4. 验证签名(使用hash值)
|
||||||
|
valid2, err := keyPair.VerifyMessage(hashBytesReal[:], signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("先hash再签名和验证结果: %v", valid2)
|
||||||
|
assert.True(t, valid2, "先hash再签名和验证应该成功")
|
||||||
|
|
||||||
|
// 5. 交叉验证:用原始数据验证hash后的签名
|
||||||
|
valid3, err := keyPair.VerifyMessage(originalData, signature2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("用原始数据验证hash后的签名结果: %v", valid3)
|
||||||
|
|
||||||
|
// 6. 交叉验证:用hash值验证原始数据的签名
|
||||||
|
valid4, err := keyPair.VerifyMessage(hashBytesReal[:], signature1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("用hash值验证原始数据的签名结果: %v", valid4)
|
||||||
|
|
||||||
|
// 结论:
|
||||||
|
// - 如果valid1=true且valid4=false,说明SM2内部可能处理了hash,或者有某种兼容性
|
||||||
|
// - 如果valid1=true且valid4=true,说明SM2可能接受原始数据(不符合文档)
|
||||||
|
// - 如果valid1=false,说明SM2确实需要hash值
|
||||||
|
|
||||||
|
t.Logf("\n结论分析:")
|
||||||
|
t.Logf("- 直接对原始数据签名和验证: %v", valid1)
|
||||||
|
t.Logf("- 先hash再签名和验证: %v", valid2)
|
||||||
|
t.Logf("- 交叉验证1(原始数据 vs hash签名): %v", valid3)
|
||||||
|
t.Logf("- 交叉验证2(hash数据 vs 原始签名): %v", valid4)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case valid1 && !valid4:
|
||||||
|
t.Logf("✓ SM2库可能内部处理了hash,或者有兼容性机制")
|
||||||
|
t.Logf("✓ 当前实现(直接使用原始数据)可能是可行的")
|
||||||
|
case valid1 && valid4:
|
||||||
|
t.Logf("⚠ SM2库可能接受原始数据,与文档不符")
|
||||||
|
t.Logf("⚠ 但当前实现可以工作")
|
||||||
|
default:
|
||||||
|
t.Logf("✗ SM2确实需要hash值,当前实现可能有问题")
|
||||||
|
}
|
||||||
|
}
|
||||||
13
api/model/trustlog.go
Normal file
13
api/model/trustlog.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "encoding"
|
||||||
|
|
||||||
|
// Trustlog 接口定义了信任日志的基本操作。
|
||||||
|
// 实现了此接口的类型可以进行序列化、反序列化、哈希计算和提供生产者ID。
|
||||||
|
type Trustlog interface {
|
||||||
|
Hashable
|
||||||
|
encoding.BinaryMarshaler
|
||||||
|
encoding.BinaryUnmarshaler
|
||||||
|
GetProducerID() string
|
||||||
|
Key() string
|
||||||
|
}
|
||||||
32
api/model/validation.go
Normal file
32
api/model/validation.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
// Validation status codes.
|
||||||
|
const (
|
||||||
|
ValidationCodeProcessing = 100 // 处理中
|
||||||
|
ValidationCodeCompleted = 200 // 完成
|
||||||
|
ValidationCodeFailed = 500 // 失败
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidationResult 包装取证的流式响应结果.
|
||||||
|
type ValidationResult struct {
|
||||||
|
Code int32 // 状态码(100处理中,200完成,500失败)
|
||||||
|
Msg string // 消息描述
|
||||||
|
Progress string // 当前进度(比如 "50%")
|
||||||
|
Data *Operation // 最终完成时返回的操作数据,过程中可为空
|
||||||
|
Proof *Proof // 取证证明(仅在完成时返回)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProcessing 判断是否正在处理中.
|
||||||
|
func (v *ValidationResult) IsProcessing() bool {
|
||||||
|
return v.Code == ValidationCodeProcessing
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCompleted 判断是否已完成.
|
||||||
|
func (v *ValidationResult) IsCompleted() bool {
|
||||||
|
return v.Code == ValidationCodeCompleted
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFailed 判断是否失败.
|
||||||
|
func (v *ValidationResult) IsFailed() bool {
|
||||||
|
return v.Code >= ValidationCodeFailed
|
||||||
|
}
|
||||||
238
api/model/validation_test.go
Normal file
238
api/model/validation_test.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidationResult_IsProcessing(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other code",
|
||||||
|
code: 99,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
v := &model.ValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, v.IsProcessing())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidationResult_IsCompleted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other code",
|
||||||
|
code: 99,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
v := &model.ValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, v.IsCompleted())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidationResult_IsFailed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "code greater than failed",
|
||||||
|
code: 501,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "code less than failed",
|
||||||
|
code: 499,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
v := &model.ValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, v.IsFailed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordValidationResult_IsProcessing(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := &model.RecordValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, r.IsProcessing())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordValidationResult_IsCompleted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := &model.RecordValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, r.IsCompleted())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordValidationResult_IsFailed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code int32
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "processing code",
|
||||||
|
code: model.ValidationCodeProcessing,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completed code",
|
||||||
|
code: model.ValidationCodeCompleted,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed code",
|
||||||
|
code: model.ValidationCodeFailed,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "code greater than failed",
|
||||||
|
code: 501,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := &model.RecordValidationResult{Code: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, r.IsFailed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
441
api/queryclient/client.go
Normal file
441
api/queryclient/client.go
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
package queryclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/grpcclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultChannelBuffer 是channel的默认缓冲区大小.
|
||||||
|
defaultChannelBuffer = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverClients 封装单个服务器的两种服务客户端.
|
||||||
|
type serverClients struct {
|
||||||
|
opClient pb.OperationValidationServiceClient
|
||||||
|
recClient pb.RecordValidationServiceClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client 查询客户端,包装gRPC客户端提供操作和记录的查询及验证功能.
|
||||||
|
type Client struct {
|
||||||
|
connLB *grpcclient.LoadBalancer[*serverClients]
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConfig 客户端配置.
|
||||||
|
type ClientConfig = grpcclient.Config
|
||||||
|
|
||||||
|
// NewClient 创建新的查询客户端.
|
||||||
|
func NewClient(config ClientConfig, logger logger.Logger) (*Client, error) {
|
||||||
|
// 获取服务器地址列表
|
||||||
|
addrs, err := config.GetAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建连接负载均衡器,每个连接同时创建两种服务的客户端
|
||||||
|
connLB, err := grpcclient.NewLoadBalancer(
|
||||||
|
addrs,
|
||||||
|
config.DialOptions,
|
||||||
|
func(conn grpc.ClientConnInterface) *serverClients {
|
||||||
|
return &serverClients{
|
||||||
|
opClient: pb.NewOperationValidationServiceClient(conn),
|
||||||
|
recClient: pb.NewRecordValidationServiceClient(conn),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Query client initialized", "serverCount", len(addrs))
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
connLB: connLB,
|
||||||
|
logger: logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOperationsRequest 列表查询请求参数.
|
||||||
|
type ListOperationsRequest struct {
|
||||||
|
// 分页参数
|
||||||
|
PageSize uint64 // 页面大小
|
||||||
|
PreTime time.Time // 上一页最后一个时间(用于游标分页)
|
||||||
|
|
||||||
|
// 可选过滤条件
|
||||||
|
Timestamp *time.Time // 操作时间戳
|
||||||
|
OpSource model.Source // 操作来源
|
||||||
|
OpType model.Type // 操作类型
|
||||||
|
DoPrefix string // 数据前缀
|
||||||
|
DoRepository string // 数据仓库
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOperationsResponse 列表查询响应.
|
||||||
|
type ListOperationsResponse struct {
|
||||||
|
Count int64 // 数据总量
|
||||||
|
Data []*model.Operation // 操作列表
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOperations 查询操作列表.
|
||||||
|
func (c *Client) ListOperations(ctx context.Context, req ListOperationsRequest) (*ListOperationsResponse, error) {
|
||||||
|
c.logger.DebugContext(ctx, "Querying operations list", "pageSize", req.PageSize)
|
||||||
|
|
||||||
|
// 使用负载均衡器获取客户端
|
||||||
|
clients := c.connLB.Next()
|
||||||
|
client := clients.opClient
|
||||||
|
|
||||||
|
// 构建protobuf请求
|
||||||
|
pbReq := &pb.ListOperationReq{
|
||||||
|
PageSize: req.PageSize,
|
||||||
|
OpSource: string(req.OpSource),
|
||||||
|
OpType: string(req.OpType),
|
||||||
|
DoPrefix: req.DoPrefix,
|
||||||
|
DoRepository: req.DoRepository,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置可选参数
|
||||||
|
if !req.PreTime.IsZero() {
|
||||||
|
pbReq.PreTime = timestamppb.New(req.PreTime)
|
||||||
|
}
|
||||||
|
if req.Timestamp != nil {
|
||||||
|
pbReq.Timestamp = timestamppb.New(*req.Timestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用gRPC
|
||||||
|
pbRes, err := client.ListOperations(ctx, pbReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list operations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换响应
|
||||||
|
operations := make([]*model.Operation, 0, len(pbRes.GetData()))
|
||||||
|
for _, pbOp := range pbRes.GetData() {
|
||||||
|
op, convertErr := model.FromProtobuf(pbOp)
|
||||||
|
if convertErr != nil {
|
||||||
|
c.logger.ErrorContext(ctx, "Failed to convert operation", "error", convertErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
operations = append(operations, op)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ListOperationsResponse{
|
||||||
|
Count: pbRes.GetCount(),
|
||||||
|
Data: operations,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationRequest 取证验证请求参数.
|
||||||
|
type ValidationRequest struct {
|
||||||
|
Time time.Time // 操作时间戳
|
||||||
|
OpID string // 操作唯一标识符
|
||||||
|
OpType string // 操作类型
|
||||||
|
DoRepository string // 数据仓库标识
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateOperation 执行操作取证验证,返回流式结果通道
|
||||||
|
// 该方法会启动一个goroutine接收流式响应,通过返回的channel发送结果
|
||||||
|
// 当流结束或发生错误时,channel会被关闭.
|
||||||
|
//
|
||||||
|
//nolint:dupl // 与 ValidateRecord 有相似逻辑,但处理不同的数据类型和 gRPC 服务
|
||||||
|
func (c *Client) ValidateOperation(ctx context.Context, req ValidationRequest) (<-chan *model.ValidationResult, error) {
|
||||||
|
c.logger.InfoContext(ctx, "Starting validation for operation", "opID", req.OpID)
|
||||||
|
|
||||||
|
// 使用负载均衡器获取客户端
|
||||||
|
clients := c.connLB.Next()
|
||||||
|
client := clients.opClient
|
||||||
|
|
||||||
|
// 构建protobuf请求
|
||||||
|
pbReq := &pb.ValidationReq{
|
||||||
|
Time: timestamppb.New(req.Time),
|
||||||
|
OpId: req.OpID,
|
||||||
|
OpType: req.OpType,
|
||||||
|
DoRepository: req.DoRepository,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用gRPC流式方法
|
||||||
|
stream, err := client.ValidateOperation(ctx, pbReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start validation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果通道
|
||||||
|
resultChan := make(chan *model.ValidationResult, defaultChannelBuffer)
|
||||||
|
|
||||||
|
// 启动goroutine接收流式响应
|
||||||
|
go func() {
|
||||||
|
defer close(resultChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
pbRes, recvErr := stream.Recv()
|
||||||
|
if recvErr != nil {
|
||||||
|
if errors.Is(recvErr, io.EOF) {
|
||||||
|
// 流正常结束
|
||||||
|
c.logger.DebugContext(ctx, "Validation stream completed", "opID", req.OpID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 发生错误
|
||||||
|
c.logger.ErrorContext(ctx, "Error receiving validation result", "error", recvErr)
|
||||||
|
// 发送错误结果
|
||||||
|
resultChan <- &model.ValidationResult{
|
||||||
|
Code: model.ValidationCodeFailed,
|
||||||
|
Msg: fmt.Sprintf("Stream error: %v", recvErr),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换并发送结果
|
||||||
|
result, convertErr := model.FromProtobufValidationResult(pbRes)
|
||||||
|
if convertErr != nil {
|
||||||
|
c.logger.ErrorContext(ctx, "Failed to convert validation result", "error", convertErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case resultChan <- result:
|
||||||
|
c.logger.DebugContext(ctx, "Sent validation result", "code", result.Code, "progress", result.Progress)
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.logger.InfoContext(ctx, "Context cancelled, stopping validation stream")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return resultChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateOperationSync 同步执行操作取证验证,阻塞直到获得最终结果
|
||||||
|
// 该方法会处理所有中间进度,只返回最终的完成结果.
|
||||||
|
func (c *Client) ValidateOperationSync(
|
||||||
|
ctx context.Context,
|
||||||
|
req ValidationRequest,
|
||||||
|
progressCallback func(*model.ValidationResult),
|
||||||
|
) (*model.ValidationResult, error) {
|
||||||
|
resultChan, err := c.ValidateOperation(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var finalResult *model.ValidationResult
|
||||||
|
for result := range resultChan {
|
||||||
|
if result.IsCompleted() || result.IsFailed() {
|
||||||
|
finalResult = result
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果提供了进度回调,则调用
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalResult == nil {
|
||||||
|
return nil, errors.New("validation completed without final result")
|
||||||
|
}
|
||||||
|
|
||||||
|
return finalResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRecordsRequest 列表查询请求参数.
|
||||||
|
type ListRecordsRequest struct {
|
||||||
|
// 分页参数
|
||||||
|
PageSize uint64 // 页面大小
|
||||||
|
PreTime time.Time // 上一页最后一个时间(用于游标分页)
|
||||||
|
|
||||||
|
// 可选过滤条件
|
||||||
|
DoPrefix string // 数据前缀
|
||||||
|
RCType string // 记录类型
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRecordsResponse 列表查询响应.
|
||||||
|
type ListRecordsResponse struct {
|
||||||
|
Count int64 // 数据总量
|
||||||
|
Data []*model.Record // 记录列表
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRecords 查询记录列表.
|
||||||
|
func (c *Client) ListRecords(ctx context.Context, req ListRecordsRequest) (*ListRecordsResponse, error) {
|
||||||
|
c.logger.DebugContext(ctx, "Querying records list", "pageSize", req.PageSize)
|
||||||
|
|
||||||
|
// 使用负载均衡器获取客户端
|
||||||
|
clients := c.connLB.Next()
|
||||||
|
client := clients.recClient
|
||||||
|
|
||||||
|
// 构建protobuf请求
|
||||||
|
pbReq := &pb.ListRecordReq{
|
||||||
|
PageSize: req.PageSize,
|
||||||
|
DoPrefix: req.DoPrefix,
|
||||||
|
RcType: req.RCType,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置可选参数
|
||||||
|
if !req.PreTime.IsZero() {
|
||||||
|
pbReq.PreTime = timestamppb.New(req.PreTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用gRPC
|
||||||
|
pbRes, err := client.ListRecords(ctx, pbReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换响应
|
||||||
|
records := make([]*model.Record, 0, len(pbRes.GetData()))
|
||||||
|
for _, pbRec := range pbRes.GetData() {
|
||||||
|
rec, convertErr := model.RecordFromProtobuf(pbRec)
|
||||||
|
if convertErr != nil {
|
||||||
|
c.logger.ErrorContext(ctx, "Failed to convert record", "error", convertErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
records = append(records, rec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ListRecordsResponse{
|
||||||
|
Count: pbRes.GetCount(),
|
||||||
|
Data: records,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordValidationRequest 记录验证请求参数.
|
||||||
|
type RecordValidationRequest struct {
|
||||||
|
Timestamp time.Time // 记录时间戳
|
||||||
|
RecordID string // 要验证的记录ID
|
||||||
|
DoPrefix string // 数据前缀(可选)
|
||||||
|
RCType string // 记录类型
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRecord 执行记录验证,返回流式结果通道
|
||||||
|
// 该方法会启动一个goroutine接收流式响应,通过返回的channel发送结果
|
||||||
|
// 当流结束或发生错误时,channel会被关闭.
|
||||||
|
//
|
||||||
|
//nolint:dupl // 与 ValidateOperation 有相似逻辑,但处理不同的数据类型和 gRPC 服务
|
||||||
|
func (c *Client) ValidateRecord(
|
||||||
|
ctx context.Context,
|
||||||
|
req RecordValidationRequest,
|
||||||
|
) (<-chan *model.RecordValidationResult, error) {
|
||||||
|
c.logger.InfoContext(ctx, "Starting validation for record", "recordID", req.RecordID)
|
||||||
|
|
||||||
|
// 使用负载均衡器获取客户端
|
||||||
|
clients := c.connLB.Next()
|
||||||
|
client := clients.recClient
|
||||||
|
|
||||||
|
// 构建protobuf请求
|
||||||
|
pbReq := &pb.RecordValidationReq{
|
||||||
|
Timestamp: timestamppb.New(req.Timestamp),
|
||||||
|
RecordId: req.RecordID,
|
||||||
|
DoPrefix: req.DoPrefix,
|
||||||
|
RcType: req.RCType,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用gRPC流式方法
|
||||||
|
stream, err := client.ValidateRecord(ctx, pbReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start validation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果通道
|
||||||
|
resultChan := make(chan *model.RecordValidationResult, defaultChannelBuffer)
|
||||||
|
|
||||||
|
// 启动goroutine接收流式响应
|
||||||
|
go func() {
|
||||||
|
defer close(resultChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
pbRes, recvErr := stream.Recv()
|
||||||
|
if recvErr != nil {
|
||||||
|
if errors.Is(recvErr, io.EOF) {
|
||||||
|
// 流正常结束
|
||||||
|
c.logger.DebugContext(ctx, "Validation stream completed", "recordID", req.RecordID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 发生错误
|
||||||
|
c.logger.ErrorContext(ctx, "Error receiving validation result", "error", recvErr)
|
||||||
|
// 发送错误结果
|
||||||
|
resultChan <- &model.RecordValidationResult{
|
||||||
|
Code: model.ValidationCodeFailed,
|
||||||
|
Msg: fmt.Sprintf("Stream error: %v", recvErr),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换并发送结果
|
||||||
|
result, convertErr := model.RecordFromProtobufValidationResult(pbRes)
|
||||||
|
if convertErr != nil {
|
||||||
|
c.logger.ErrorContext(ctx, "Failed to convert validation result", "error", convertErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case resultChan <- result:
|
||||||
|
c.logger.DebugContext(ctx, "Sent validation result", "code", result.Code, "progress", result.Progress)
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.logger.InfoContext(ctx, "Context cancelled, stopping validation stream")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return resultChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRecordSync 同步执行记录验证,阻塞直到获得最终结果
|
||||||
|
// 该方法会处理所有中间进度,只返回最终的完成结果.
|
||||||
|
func (c *Client) ValidateRecordSync(
|
||||||
|
ctx context.Context,
|
||||||
|
req RecordValidationRequest,
|
||||||
|
progressCallback func(*model.RecordValidationResult),
|
||||||
|
) (*model.RecordValidationResult, error) {
|
||||||
|
resultChan, err := c.ValidateRecord(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var finalResult *model.RecordValidationResult
|
||||||
|
for result := range resultChan {
|
||||||
|
if result.IsCompleted() || result.IsFailed() {
|
||||||
|
finalResult = result
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果提供了进度回调,则调用
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalResult == nil {
|
||||||
|
return nil, errors.New("validation completed without final result")
|
||||||
|
}
|
||||||
|
|
||||||
|
return finalResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭客户端连接.
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
if c.connLB != nil {
|
||||||
|
return c.connLB.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLowLevelOperationClient 获取底层的操作gRPC客户端(用于高级用户自定义操作)
|
||||||
|
// 注意:使用负载均衡时,每次调用此方法将返回轮询的下一个客户端.
|
||||||
|
func (c *Client) GetLowLevelOperationClient() pb.OperationValidationServiceClient {
|
||||||
|
return c.connLB.Next().opClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLowLevelRecordClient 获取底层的记录gRPC客户端(用于高级用户自定义操作)
|
||||||
|
// 注意:使用负载均衡时,每次调用此方法将返回轮询的下一个客户端.
|
||||||
|
func (c *Client) GetLowLevelRecordClient() pb.RecordValidationServiceClient {
|
||||||
|
return c.connLB.Next().recClient
|
||||||
|
}
|
||||||
627
api/queryclient/client_test.go
Normal file
627
api/queryclient/client_test.go
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
package queryclient_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/queryclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
const bufSize = 1024 * 1024
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // 测试文件中的全局变量是可接受的
|
||||||
|
var testLogger = logger.NewLogger(logr.Discard())
|
||||||
|
|
||||||
|
// mockOperationServer 模拟操作验证服务.
|
||||||
|
type mockOperationServer struct {
|
||||||
|
pb.UnimplementedOperationValidationServiceServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockOperationServer) ListOperations(
|
||||||
|
_ context.Context,
|
||||||
|
_ *pb.ListOperationReq,
|
||||||
|
) (*pb.ListOperationRes, error) {
|
||||||
|
return &pb.ListOperationRes{
|
||||||
|
Count: 2,
|
||||||
|
Data: []*pb.OperationData{
|
||||||
|
{
|
||||||
|
OpId: "op-1",
|
||||||
|
Timestamp: timestamppb.Now(),
|
||||||
|
OpSource: "test",
|
||||||
|
OpType: "create",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "tester",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
OpId: "op-2",
|
||||||
|
Timestamp: timestamppb.Now(),
|
||||||
|
OpSource: "test",
|
||||||
|
OpType: "update",
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: "repo",
|
||||||
|
Doid: "test/repo/456",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "tester",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockOperationServer) ValidateOperation(
|
||||||
|
req *pb.ValidationReq,
|
||||||
|
stream pb.OperationValidationService_ValidateOperationServer,
|
||||||
|
) error {
|
||||||
|
// 发送进度消息
|
||||||
|
_ = stream.Send(&pb.ValidationStreamRes{
|
||||||
|
Code: 100,
|
||||||
|
Msg: "Processing",
|
||||||
|
Progress: "50%",
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送完成消息
|
||||||
|
_ = stream.Send(&pb.ValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Data: &pb.OperationData{
|
||||||
|
OpId: req.GetOpId(),
|
||||||
|
Timestamp: req.GetTime(),
|
||||||
|
OpSource: "test",
|
||||||
|
OpType: req.GetOpType(),
|
||||||
|
DoPrefix: "test",
|
||||||
|
DoRepository: req.GetDoRepository(),
|
||||||
|
Doid: "test/repo/123",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
OpActor: "tester",
|
||||||
|
},
|
||||||
|
Proof: &pb.Proof{
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockRecordServer 模拟记录验证服务.
|
||||||
|
type mockRecordServer struct {
|
||||||
|
pb.UnimplementedRecordValidationServiceServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockRecordServer) ListRecords(
|
||||||
|
_ context.Context,
|
||||||
|
_ *pb.ListRecordReq,
|
||||||
|
) (*pb.ListRecordRes, error) {
|
||||||
|
return &pb.ListRecordRes{
|
||||||
|
Count: 2,
|
||||||
|
Data: []*pb.RecordData{
|
||||||
|
{
|
||||||
|
Id: "rec-1",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Timestamp: timestamppb.Now(),
|
||||||
|
Operator: "tester",
|
||||||
|
RcType: "log",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "rec-2",
|
||||||
|
DoPrefix: "test",
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Timestamp: timestamppb.Now(),
|
||||||
|
Operator: "tester",
|
||||||
|
RcType: "log",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockRecordServer) ValidateRecord(
|
||||||
|
req *pb.RecordValidationReq,
|
||||||
|
stream pb.RecordValidationService_ValidateRecordServer,
|
||||||
|
) error {
|
||||||
|
// 发送进度消息
|
||||||
|
_ = stream.Send(&pb.RecordValidationStreamRes{
|
||||||
|
Code: 100,
|
||||||
|
Msg: "Processing",
|
||||||
|
Progress: "50%",
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送完成消息
|
||||||
|
_ = stream.Send(&pb.RecordValidationStreamRes{
|
||||||
|
Code: 200,
|
||||||
|
Msg: "Completed",
|
||||||
|
Progress: "100%",
|
||||||
|
Result: &pb.RecordData{
|
||||||
|
Id: req.GetRecordId(),
|
||||||
|
DoPrefix: req.GetDoPrefix(),
|
||||||
|
ProducerId: "producer-1",
|
||||||
|
Timestamp: req.GetTimestamp(),
|
||||||
|
Operator: "tester",
|
||||||
|
RcType: req.GetRcType(),
|
||||||
|
},
|
||||||
|
Proof: &pb.Proof{
|
||||||
|
ColItems: []*pb.MerkleTreeProofItem{
|
||||||
|
{Floor: 1, Hash: "hash1", Left: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTestServer 创建测试用的 gRPC server.
|
||||||
|
func setupTestServer(t *testing.T) (*grpc.Server, *bufconn.Listener) {
|
||||||
|
lis := bufconn.Listen(bufSize)
|
||||||
|
s := grpc.NewServer()
|
||||||
|
pb.RegisterOperationValidationServiceServer(s, &mockOperationServer{})
|
||||||
|
pb.RegisterRecordValidationServiceServer(s, &mockRecordServer{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.Serve(lis); err != nil {
|
||||||
|
t.Logf("Server exited with error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTestClient 创建用于测试的客户端.
|
||||||
|
//
|
||||||
|
//nolint:unparam // 集成测试暂时跳过,返回值始终为 nil
|
||||||
|
func createTestClient(t *testing.T, _ *bufconn.Listener) *queryclient.Client {
|
||||||
|
// 使用 bufconn 的特殊方式创建客户端
|
||||||
|
// 由于我们不能直接注入连接,需要通过地址的方式
|
||||||
|
// 这里我们使用一个变通的方法:直接构建客户端结构(不推荐生产使用)
|
||||||
|
// 更好的方法是提供一个可注入连接的构造函数
|
||||||
|
|
||||||
|
// 暂时使用真实的地址测试配置验证
|
||||||
|
client, err := queryclient.NewClient(queryclient.ClientConfig{
|
||||||
|
ServerAddr: "bufnet",
|
||||||
|
}, testLogger)
|
||||||
|
|
||||||
|
// 对于这个测试,我们关闭它并使用 mock 方式
|
||||||
|
if client != nil {
|
||||||
|
_ = client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 err 避免未使用的警告
|
||||||
|
_ = err
|
||||||
|
|
||||||
|
// 返回 nil,让调用者知道需要用其他方式测试
|
||||||
|
t.Skip("Skipping integration test - requires real gRPC server setup")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClient(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config queryclient.ClientConfig
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "使用ServerAddr成功创建客户端",
|
||||||
|
config: queryclient.ClientConfig{
|
||||||
|
ServerAddr: "localhost:9090",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "使用ServerAddrs成功创建客户端",
|
||||||
|
config: queryclient.ClientConfig{
|
||||||
|
ServerAddrs: []string{"localhost:9090", "localhost:9091"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "没有提供地址应该失败",
|
||||||
|
config: queryclient.ClientConfig{},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "at least one server address is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
client, err := queryclient.NewClient(tt.config, testLogger)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
|
assert.Nil(t, client)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, client)
|
||||||
|
// 清理
|
||||||
|
if client != nil {
|
||||||
|
_ = client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConfig_GetAddrs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config queryclient.ClientConfig
|
||||||
|
wantAddrs []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ServerAddrs优先",
|
||||||
|
config: queryclient.ClientConfig{
|
||||||
|
ServerAddrs: []string{"addr1:9090", "addr2:9090"},
|
||||||
|
ServerAddr: "addr3:9090",
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"addr1:9090", "addr2:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "使用ServerAddr作为后备",
|
||||||
|
config: queryclient.ClientConfig{
|
||||||
|
ServerAddr: "addr1:9090",
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"addr1:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "没有地址应该返回错误",
|
||||||
|
config: queryclient.ClientConfig{},
|
||||||
|
wantAddrs: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addrs, err := tt.config.GetAddrs()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.wantAddrs, addrs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListOperationsRequest(t *testing.T) {
|
||||||
|
// 测试请求结构的创建
|
||||||
|
now := time.Now()
|
||||||
|
req := queryclient.ListOperationsRequest{
|
||||||
|
PageSize: 10,
|
||||||
|
PreTime: now,
|
||||||
|
Timestamp: &now,
|
||||||
|
OpSource: model.Source("test"),
|
||||||
|
OpType: model.Type("create"),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(10), req.PageSize)
|
||||||
|
assert.Equal(t, now, req.PreTime)
|
||||||
|
assert.NotNil(t, req.Timestamp)
|
||||||
|
assert.Equal(t, "test", string(req.OpSource))
|
||||||
|
assert.Equal(t, "create", string(req.OpType))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidationRequest(t *testing.T) {
|
||||||
|
// 测试验证请求结构
|
||||||
|
now := time.Now()
|
||||||
|
req := queryclient.ValidationRequest{
|
||||||
|
Time: now,
|
||||||
|
OpID: "op-123",
|
||||||
|
OpType: "create",
|
||||||
|
DoRepository: "repo",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, now, req.Time)
|
||||||
|
assert.Equal(t, "op-123", req.OpID)
|
||||||
|
assert.Equal(t, "create", req.OpType)
|
||||||
|
assert.Equal(t, "repo", req.DoRepository)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRecordsRequest(t *testing.T) {
|
||||||
|
// 测试记录列表请求结构
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
req := queryclient.ListRecordsRequest{
|
||||||
|
PageSize: 20,
|
||||||
|
PreTime: now,
|
||||||
|
DoPrefix: "test",
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(20), req.PageSize)
|
||||||
|
assert.Equal(t, now, req.PreTime)
|
||||||
|
assert.Equal(t, "test", req.DoPrefix)
|
||||||
|
assert.Equal(t, "log", req.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordValidationRequest(t *testing.T) {
|
||||||
|
// 测试记录验证请求结构
|
||||||
|
now := time.Now()
|
||||||
|
req := queryclient.RecordValidationRequest{
|
||||||
|
Timestamp: now,
|
||||||
|
RecordID: "rec-123",
|
||||||
|
DoPrefix: "test",
|
||||||
|
RCType: "log",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, now, req.Timestamp)
|
||||||
|
assert.Equal(t, "rec-123", req.RecordID)
|
||||||
|
assert.Equal(t, "test", req.DoPrefix)
|
||||||
|
assert.Equal(t, "log", req.RCType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 集成测试部分(需要真实的 gRPC server).
|
||||||
|
func TestIntegration_ListOperations(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := client.ListOperations(ctx, queryclient.ListOperationsRequest{
|
||||||
|
PageSize: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, int64(2), resp.Count)
|
||||||
|
assert.Len(t, resp.Data, 2)
|
||||||
|
assert.Equal(t, "op-1", resp.Data[0].OpID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ValidateOperation(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resultChan, err := client.ValidateOperation(ctx, queryclient.ValidationRequest{
|
||||||
|
Time: time.Now(),
|
||||||
|
OpID: "op-test",
|
||||||
|
OpType: "create",
|
||||||
|
DoRepository: "repo",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultChan)
|
||||||
|
|
||||||
|
results := []int32{}
|
||||||
|
for result := range resultChan {
|
||||||
|
results = append(results, result.Code)
|
||||||
|
if result.IsCompleted() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, results, int32(100)) // Processing
|
||||||
|
assert.Contains(t, results, int32(200)) // Completed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ValidateOperationSync(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
progressCount := 0
|
||||||
|
result, err := client.ValidateOperationSync(
|
||||||
|
ctx,
|
||||||
|
queryclient.ValidationRequest{
|
||||||
|
Time: time.Now(),
|
||||||
|
OpID: "op-test",
|
||||||
|
OpType: "create",
|
||||||
|
DoRepository: "repo",
|
||||||
|
},
|
||||||
|
func(r *model.ValidationResult) {
|
||||||
|
progressCount++
|
||||||
|
assert.Equal(t, int32(100), r.Code)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.True(t, result.IsCompleted())
|
||||||
|
assert.Positive(t, progressCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ListRecords(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := client.ListRecords(ctx, queryclient.ListRecordsRequest{
|
||||||
|
PageSize: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, int64(2), resp.Count)
|
||||||
|
assert.Len(t, resp.Data, 2)
|
||||||
|
assert.Equal(t, "rec-1", resp.Data[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ValidateRecord(t *testing.T) { //nolint:dupl // 测试代码中的重复模式是合理的
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resultChan, err := client.ValidateRecord(ctx, queryclient.RecordValidationRequest{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
RecordID: "rec-test",
|
||||||
|
DoPrefix: "test",
|
||||||
|
RCType: "log",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultChan)
|
||||||
|
|
||||||
|
results := []int32{}
|
||||||
|
for result := range resultChan {
|
||||||
|
results = append(results, result.Code)
|
||||||
|
if result.IsCompleted() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, results, int32(100)) // Processing
|
||||||
|
assert.Contains(t, results, int32(200)) // Completed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ValidateRecordSync(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
progressCount := 0
|
||||||
|
result, err := client.ValidateRecordSync(
|
||||||
|
ctx,
|
||||||
|
queryclient.RecordValidationRequest{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
RecordID: "rec-test",
|
||||||
|
DoPrefix: "test",
|
||||||
|
RCType: "log",
|
||||||
|
},
|
||||||
|
func(r *model.RecordValidationResult) {
|
||||||
|
progressCount++
|
||||||
|
assert.Equal(t, int32(100), r.Code)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, int32(200), result.Code)
|
||||||
|
assert.True(t, result.IsCompleted())
|
||||||
|
assert.Positive(t, progressCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetLowLevelClients(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
opClient := client.GetLowLevelOperationClient()
|
||||||
|
assert.NotNil(t, opClient)
|
||||||
|
|
||||||
|
recClient := client.GetLowLevelRecordClient()
|
||||||
|
assert.NotNil(t, recClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Close(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, lis := setupTestServer(t)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
client := createTestClient(t, lis)
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := client.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 再次关闭应该不会报错
|
||||||
|
err = client.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
18
cookiecutter-config-file.yml
Normal file
18
cookiecutter-config-file.yml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# This project was generated using a Cookiecutter template:
|
||||||
|
# https://github.com/daotl/go-template
|
||||||
|
|
||||||
|
cookiecutter_inputs:
|
||||||
|
project_name: "go-trustlog"
|
||||||
|
go_module_path: "gitea.internetapi.cn/trustlog-sd"
|
||||||
|
license_owner: "example"
|
||||||
|
base_branch: "main"
|
||||||
|
contact_email: ""
|
||||||
|
project_description: "TrustlogSdk is a Go application created using https://github.com/daotl/go-template"
|
||||||
|
github_specific_features: "n"
|
||||||
|
private_project: "y"
|
||||||
|
use_codecov: "y"
|
||||||
|
use_lefthook: "y"
|
||||||
|
use_precommit: "y"
|
||||||
|
go_version: "1.24"
|
||||||
|
go_toolchain_version: "1.24.5"
|
||||||
|
license: "MIT"
|
||||||
97
go.mod
Normal file
97
go.mod
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
module go.yandata.net/iod/iod/go-trustlog
|
||||||
|
|
||||||
|
go 1.25
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/ThreeDotsLabs/watermill v1.5.1
|
||||||
|
github.com/apache/pulsar-client-go v0.17.0
|
||||||
|
github.com/crpt/go-crpt v1.0.0
|
||||||
|
github.com/fxamacker/cbor/v2 v2.7.0
|
||||||
|
github.com/go-logr/logr v1.4.3
|
||||||
|
github.com/go-playground/validator/v10 v10.28.0
|
||||||
|
github.com/minio/sha256-simd v1.0.1
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/zeebo/blake3 v0.2.4
|
||||||
|
golang.org/x/crypto v0.43.0
|
||||||
|
google.golang.org/grpc v1.75.0
|
||||||
|
google.golang.org/protobuf v1.36.8
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect
|
||||||
|
github.com/99designs/keyring v1.2.1 // indirect
|
||||||
|
github.com/AthenZ/athenz v1.12.13 // indirect
|
||||||
|
github.com/DataDog/zstd v1.5.0 // indirect
|
||||||
|
github.com/ardielle/ardielle-go v1.5.2 // indirect
|
||||||
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/bits-and-blooms/bitset v1.4.0 // indirect
|
||||||
|
github.com/btcsuite/btcd v0.22.0-beta // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/crpt/go-merkle v0.0.0-20211202024952-07ef5d0dcfc0 // indirect
|
||||||
|
github.com/danieljoos/wincred v1.1.2 // indirect
|
||||||
|
github.com/daotl/go-acei v0.0.0-20211201154418-8daef5059165 // indirect
|
||||||
|
github.com/daotl/guts v0.0.0-20211209102048-f83c8ade78e8 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/dvsekhvalnov/jose2go v1.6.0 // indirect
|
||||||
|
github.com/emmansun/gmsm v0.40.0 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.1 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
|
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect
|
||||||
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||||
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
|
||||||
|
github.com/hamba/avro/v2 v2.29.0 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/libp2p/go-msgio v0.1.0 // indirect
|
||||||
|
github.com/lithammer/shortuuid/v3 v3.0.7 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/mr-tron/base58 v1.2.0 // indirect
|
||||||
|
github.com/mtibben/percent v0.2.1 // indirect
|
||||||
|
github.com/multiformats/go-multihash v0.2.3 // indirect
|
||||||
|
github.com/multiformats/go-varint v0.0.6 // indirect
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
|
github.com/oasisprotocol/curve25519-voi v0.0.0-20211129104401-1d84291be125 // indirect
|
||||||
|
github.com/oklog/ulid v1.3.1 // indirect
|
||||||
|
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
|
||||||
|
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
github.com/prometheus/client_golang v1.23.0 // indirect
|
||||||
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
|
github.com/prometheus/common v0.65.0 // indirect
|
||||||
|
github.com/prometheus/procfs v0.17.0 // indirect
|
||||||
|
github.com/sasha-s/go-deadlock v0.2.1-0.20190427202633-1595213edefa // indirect
|
||||||
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
|
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
|
github.com/tendermint/tendermint v0.35.0 // indirect
|
||||||
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
|
golang.org/x/mod v0.29.0 // indirect
|
||||||
|
golang.org/x/net v0.46.0 // indirect
|
||||||
|
golang.org/x/oauth2 v0.30.0 // indirect
|
||||||
|
golang.org/x/sys v0.37.0 // indirect
|
||||||
|
golang.org/x/term v0.36.0 // indirect
|
||||||
|
golang.org/x/text v0.30.0 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect
|
||||||
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
k8s.io/apimachinery v0.32.3 // indirect
|
||||||
|
k8s.io/client-go v0.32.3 // indirect
|
||||||
|
k8s.io/klog/v2 v2.130.1 // indirect
|
||||||
|
k8s.io/utils v0.0.0-20250321185631-1f6e0b77f77e // indirect
|
||||||
|
lukechampine.com/blake3 v1.1.6 // indirect
|
||||||
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
|
||||||
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
|
||||||
|
sigs.k8s.io/yaml v1.4.0 // indirect
|
||||||
|
)
|
||||||
30
internal/grpcclient/config.go
Normal file
30
internal/grpcclient/config.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package grpcclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config 客户端配置.
|
||||||
|
type Config struct {
|
||||||
|
// ServerAddrs gRPC服务器地址列表,格式: "host:port"
|
||||||
|
// 支持多个地址,客户端将使用轮询负载均衡
|
||||||
|
ServerAddrs []string
|
||||||
|
// ServerAddr 单个服务器地址(向后兼容),如果设置了此字段,将忽略ServerAddrs
|
||||||
|
ServerAddr string
|
||||||
|
// DialOptions 额外的gRPC拨号选项
|
||||||
|
DialOptions []grpc.DialOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAddrs 获取服务器地址列表.
|
||||||
|
func (c *Config) GetAddrs() ([]string, error) {
|
||||||
|
switch {
|
||||||
|
case len(c.ServerAddrs) > 0:
|
||||||
|
return c.ServerAddrs, nil
|
||||||
|
case c.ServerAddr != "":
|
||||||
|
return []string{c.ServerAddr}, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("at least one server address is required")
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/grpcclient/config_test.go
Normal file
119
internal/grpcclient/config_test.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package grpcclient_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/grpcclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_GetAddrs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config grpcclient.Config
|
||||||
|
wantAddrs []string
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ServerAddrs优先级高于ServerAddr",
|
||||||
|
config: grpcclient.Config{
|
||||||
|
ServerAddrs: []string{"server1:9090", "server2:9090"},
|
||||||
|
ServerAddr: "server3:9090",
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"server1:9090", "server2:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "只有ServerAddrs",
|
||||||
|
config: grpcclient.Config{
|
||||||
|
ServerAddrs: []string{"server1:9090", "server2:9090", "server3:9090"},
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"server1:9090", "server2:9090", "server3:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "只有ServerAddr",
|
||||||
|
config: grpcclient.Config{
|
||||||
|
ServerAddr: "server1:9090",
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"server1:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ServerAddrs为空,使用ServerAddr",
|
||||||
|
config: grpcclient.Config{
|
||||||
|
ServerAddrs: []string{},
|
||||||
|
ServerAddr: "server1:9090",
|
||||||
|
},
|
||||||
|
wantAddrs: []string{"server1:9090"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "没有任何地址应该返回错误",
|
||||||
|
config: grpcclient.Config{},
|
||||||
|
wantAddrs: nil,
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "at least one server address is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ServerAddrs为空且ServerAddr为空",
|
||||||
|
config: grpcclient.Config{
|
||||||
|
ServerAddrs: []string{},
|
||||||
|
ServerAddr: "",
|
||||||
|
},
|
||||||
|
wantAddrs: nil,
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "at least one server address is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addrs, err := tt.config.GetAddrs()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
|
assert.Nil(t, addrs)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.wantAddrs, addrs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_EmptyServerAddrs(t *testing.T) {
|
||||||
|
// 测试空的 ServerAddrs 切片
|
||||||
|
config := grpcclient.Config{
|
||||||
|
ServerAddrs: []string{},
|
||||||
|
ServerAddr: "fallback:9090",
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := config.GetAddrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"fallback:9090"}, addrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_MultipleServerAddrs(t *testing.T) {
|
||||||
|
// 测试多个服务器地址
|
||||||
|
config := grpcclient.Config{
|
||||||
|
ServerAddrs: []string{
|
||||||
|
"server1:9090",
|
||||||
|
"server2:9091",
|
||||||
|
"server3:9092",
|
||||||
|
"server4:9093",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := config.GetAddrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, addrs, 4)
|
||||||
|
assert.Equal(t, "server1:9090", addrs[0])
|
||||||
|
assert.Equal(t, "server4:9093", addrs[3])
|
||||||
|
}
|
||||||
113
internal/grpcclient/loadbalancer.go
Normal file
113
internal/grpcclient/loadbalancer.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package grpcclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientFactory 客户端工厂函数类型.
|
||||||
|
type ClientFactory[T any] func(grpc.ClientConnInterface) T
|
||||||
|
|
||||||
|
// ServerClient 封装单个服务器的连接.
|
||||||
|
type ServerClient[T any] struct {
|
||||||
|
addr string
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
client T
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadBalancer 轮询负载均衡器(泛型版本).
|
||||||
|
type LoadBalancer[T any] struct {
|
||||||
|
servers []*ServerClient[T]
|
||||||
|
counter atomic.Uint64
|
||||||
|
mu sync.RWMutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLoadBalancer 创建新的负载均衡器.
|
||||||
|
func NewLoadBalancer[T any](
|
||||||
|
addrs []string,
|
||||||
|
dialOpts []grpc.DialOption,
|
||||||
|
factory ClientFactory[T],
|
||||||
|
) (*LoadBalancer[T], error) {
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return nil, errors.New("at least one server address is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
lb := &LoadBalancer[T]{
|
||||||
|
servers: make([]*ServerClient[T], 0, len(addrs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认使用不安全的连接(生产环境应使用TLS)
|
||||||
|
opts := []grpc.DialOption{
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
}
|
||||||
|
opts = append(opts, dialOpts...)
|
||||||
|
|
||||||
|
// 连接所有服务器
|
||||||
|
for _, addr := range addrs {
|
||||||
|
conn, err := grpc.NewClient(addr, opts...)
|
||||||
|
if err != nil {
|
||||||
|
// 关闭已创建的连接
|
||||||
|
_ = lb.Close()
|
||||||
|
return nil, fmt.Errorf("failed to connect to server %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := factory(conn)
|
||||||
|
lb.servers = append(lb.servers, &ServerClient[T]{
|
||||||
|
addr: addr,
|
||||||
|
conn: conn,
|
||||||
|
client: client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return lb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next 使用轮询算法获取下一个客户端.
|
||||||
|
func (lb *LoadBalancer[T]) Next() T {
|
||||||
|
lb.mu.RLock()
|
||||||
|
defer lb.mu.RUnlock()
|
||||||
|
|
||||||
|
if len(lb.servers) == 0 || lb.closed {
|
||||||
|
var zero T
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原子递增计数器并取模
|
||||||
|
idx := lb.counter.Add(1) % uint64(len(lb.servers))
|
||||||
|
return lb.servers[idx].client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭所有连接.
|
||||||
|
func (lb *LoadBalancer[T]) Close() error {
|
||||||
|
lb.mu.Lock()
|
||||||
|
defer lb.mu.Unlock()
|
||||||
|
|
||||||
|
// 如果已经关闭,直接返回
|
||||||
|
if lb.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for _, server := range lb.servers {
|
||||||
|
if err := server.conn.Close(); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记为已关闭
|
||||||
|
lb.closed = true
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerCount 返回服务器数量.
|
||||||
|
func (lb *LoadBalancer[T]) ServerCount() int {
|
||||||
|
lb.mu.RLock()
|
||||||
|
defer lb.mu.RUnlock()
|
||||||
|
return len(lb.servers)
|
||||||
|
}
|
||||||
186
internal/grpcclient/loadbalancer_test.go
Normal file
186
internal/grpcclient/loadbalancer_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package grpcclient_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/grpcclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockClient 用于测试的模拟客户端.
|
||||||
|
type mockClient struct {
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLoadBalancer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addrs []string
|
||||||
|
dialOpts []grpc.DialOption
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "成功创建负载均衡器",
|
||||||
|
addrs: []string{
|
||||||
|
"localhost:9090",
|
||||||
|
"localhost:9091",
|
||||||
|
},
|
||||||
|
dialOpts: []grpc.DialOption{
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "没有地址应该失败",
|
||||||
|
addrs: []string{},
|
||||||
|
dialOpts: nil,
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "at least one server address is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil地址列表应该失败",
|
||||||
|
addrs: nil,
|
||||||
|
dialOpts: nil,
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "at least one server address is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
lb, err := grpcclient.NewLoadBalancer(
|
||||||
|
tt.addrs,
|
||||||
|
tt.dialOpts,
|
||||||
|
func(_ grpc.ClientConnInterface) *mockClient {
|
||||||
|
return &mockClient{ID: "test"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
|
assert.Nil(t, lb)
|
||||||
|
} else {
|
||||||
|
// 注意:这里会实际尝试连接,在测试环境下可能失败
|
||||||
|
// 实际使用时应该使用 mock 或 bufconn
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Skipping test - cannot connect to servers: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, lb)
|
||||||
|
assert.Equal(t, len(tt.addrs), lb.ServerCount())
|
||||||
|
// 清理
|
||||||
|
_ = lb.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadBalancer_Next(t *testing.T) {
|
||||||
|
// 创建一个模拟的负载均衡器,不需要真实连接
|
||||||
|
t.Run("轮询算法测试", func(t *testing.T) {
|
||||||
|
// 这个测试需要使用 bufconn 或其他 mock 方式
|
||||||
|
// 暂时跳过需要真实连接的测试
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping test that requires network connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs := []string{"localhost:9090", "localhost:9091", "localhost:9092"}
|
||||||
|
lb, err := grpcclient.NewLoadBalancer(
|
||||||
|
addrs,
|
||||||
|
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||||
|
func(_ grpc.ClientConnInterface) *mockClient {
|
||||||
|
return &mockClient{ID: "test"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Cannot create load balancer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer lb.Close()
|
||||||
|
|
||||||
|
// 测试轮询:调用 Next() 多次应该轮询返回不同的客户端
|
||||||
|
clients := make([]*mockClient, 6)
|
||||||
|
for i := range 6 {
|
||||||
|
clients[i] = lb.Next()
|
||||||
|
assert.NotNil(t, clients[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadBalancer_Close(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping test that requires network connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs := []string{"localhost:9090"}
|
||||||
|
lb, err := grpcclient.NewLoadBalancer(
|
||||||
|
addrs,
|
||||||
|
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||||
|
func(_ grpc.ClientConnInterface) *mockClient {
|
||||||
|
return &mockClient{ID: "test"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Cannot create load balancer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次关闭
|
||||||
|
err = lb.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 再次关闭应该也不会报错
|
||||||
|
err = lb.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadBalancer_ServerCount(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping test that requires network connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addrs []string
|
||||||
|
wantCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "单服务器",
|
||||||
|
addrs: []string{"localhost:9090"},
|
||||||
|
wantCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多服务器",
|
||||||
|
addrs: []string{"localhost:9090", "localhost:9091", "localhost:9092"},
|
||||||
|
wantCount: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
lb, err := grpcclient.NewLoadBalancer(
|
||||||
|
tt.addrs,
|
||||||
|
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||||
|
func(_ grpc.ClientConnInterface) *mockClient {
|
||||||
|
return &mockClient{ID: "test"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Cannot create load balancer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer lb.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantCount, lb.ServerCount())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
48
internal/helpers/cbor.go
Normal file
48
internal/helpers/cbor.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//nolint:gochecknoglobals // 使用 sync.Once 模式需要全局变量来确保单次初始化
|
||||||
|
canonicalEncModeOnce sync.Once
|
||||||
|
canonicalEncMode cbor.EncMode //nolint:gochecknoglobals // 使用 sync.Once 模式需要全局变量来确保单次初始化
|
||||||
|
errCanonicalEncMode error
|
||||||
|
)
|
||||||
|
|
||||||
|
// getCanonicalEncMode 获取 Canonical CBOR 编码模式。
|
||||||
|
// 使用 Canonical CBOR 编码模式,确保序列化结果的一致性。
|
||||||
|
// Canonical CBOR 遵循 RFC 7049 Section 3.9,保证相同数据在不同实现间产生相同的字节序列。
|
||||||
|
// 使用 TimeRFC3339Nano 模式确保 time.Time 的纳秒精度被完整保留。
|
||||||
|
func getCanonicalEncMode() (cbor.EncMode, error) {
|
||||||
|
canonicalEncModeOnce.Do(func() {
|
||||||
|
opts := cbor.CanonicalEncOptions()
|
||||||
|
// 设置时间编码模式为 RFC3339Nano,以保留纳秒精度
|
||||||
|
opts.Time = cbor.TimeRFC3339Nano
|
||||||
|
canonicalEncMode, errCanonicalEncMode = opts.EncMode()
|
||||||
|
if errCanonicalEncMode != nil {
|
||||||
|
errCanonicalEncMode = fmt.Errorf("failed to create canonical CBOR encoder: %w", errCanonicalEncMode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return canonicalEncMode, errCanonicalEncMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalCanonical 使用 Canonical CBOR 编码序列化数据。
|
||||||
|
// 确保相同数据在不同实现间产生相同的字节序列,适用于需要确定性序列化的场景。
|
||||||
|
func MarshalCanonical(v interface{}) ([]byte, error) {
|
||||||
|
encMode, err := getCanonicalEncMode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return encMode.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal 反序列化 CBOR 数据。
|
||||||
|
// 支持标准 CBOR 和 Canonical CBOR 格式。
|
||||||
|
func Unmarshal(data []byte, v interface{}) error {
|
||||||
|
return cbor.Unmarshal(data, v)
|
||||||
|
}
|
||||||
177
internal/helpers/cbor_test.go
Normal file
177
internal/helpers/cbor_test.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package helpers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMarshalCanonical(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int",
|
||||||
|
input: 42,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map",
|
||||||
|
input: map[string]interface{}{"key": "value"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "slice",
|
||||||
|
input: []string{"a", "b", "c"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "struct",
|
||||||
|
input: struct{ Name string }{"test"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
input: nil,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
result, err := helpers.MarshalCanonical(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalCanonical_Deterministic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
input := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
"key3": 123,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err1 := helpers.MarshalCanonical(input)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
|
||||||
|
result2, err2 := helpers.MarshalCanonical(input)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
|
||||||
|
// Canonical encoding should produce identical results
|
||||||
|
assert.Equal(t, result1, result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
target interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
data: []byte{0x64, 0x74, 0x65, 0x73, 0x74}, // "test" in CBOR
|
||||||
|
target: new(string),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int",
|
||||||
|
data: []byte{0x18, 0x2a}, // 42 in CBOR
|
||||||
|
target: new(int),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid CBOR",
|
||||||
|
data: []byte{0xff, 0xff, 0xff},
|
||||||
|
target: new(string),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty data",
|
||||||
|
data: []byte{},
|
||||||
|
target: new(string),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := helpers.Unmarshal(tt.data, tt.target)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshal_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test string",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int",
|
||||||
|
input: 42,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map",
|
||||||
|
input: map[string]interface{}{"key": "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "slice",
|
||||||
|
input: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := helpers.MarshalCanonical(tt.input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var result interface{}
|
||||||
|
err = helpers.Unmarshal(data, &result)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
76
internal/helpers/cbor_time_test.go
Normal file
76
internal/helpers/cbor_time_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package helpers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCBORTimePrecision(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// 创建一个包含纳秒精度的时间戳
|
||||||
|
originalTime := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||||
|
|
||||||
|
t.Logf("Original time: %v", originalTime)
|
||||||
|
t.Logf("Original nanoseconds: %d", originalTime.Nanosecond())
|
||||||
|
|
||||||
|
// 序列化
|
||||||
|
data, err := helpers.MarshalCanonical(originalTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// 反序列化
|
||||||
|
var decodedTime time.Time
|
||||||
|
err = helpers.Unmarshal(data, &decodedTime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Decoded time: %v", decodedTime)
|
||||||
|
t.Logf("Decoded nanoseconds: %d", decodedTime.Nanosecond())
|
||||||
|
|
||||||
|
// 验证纳秒精度是否保留
|
||||||
|
assert.Equal(t, originalTime.UnixNano(), decodedTime.UnixNano(),
|
||||||
|
"纳秒精度应该被保留")
|
||||||
|
assert.Equal(t, originalTime.Nanosecond(), decodedTime.Nanosecond(),
|
||||||
|
"纳秒部分应该相等")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCBORTimePrecision_Struct(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
type TestStruct struct {
|
||||||
|
Timestamp time.Time `cbor:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建一个包含纳秒精度的时间戳
|
||||||
|
originalTime := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||||
|
original := TestStruct{
|
||||||
|
Timestamp: originalTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||||
|
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 序列化
|
||||||
|
data, err := helpers.MarshalCanonical(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
|
||||||
|
// 反序列化
|
||||||
|
var decoded TestStruct
|
||||||
|
err = helpers.Unmarshal(data, &decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Decoded timestamp: %v", decoded.Timestamp)
|
||||||
|
t.Logf("Decoded nanoseconds: %d", decoded.Timestamp.Nanosecond())
|
||||||
|
|
||||||
|
// 验证纳秒精度是否保留
|
||||||
|
assert.Equal(t, original.Timestamp.UnixNano(), decoded.Timestamp.UnixNano(),
|
||||||
|
"纳秒精度应该被保留")
|
||||||
|
assert.Equal(t, original.Timestamp.Nanosecond(), decoded.Timestamp.Nanosecond(),
|
||||||
|
"纳秒部分应该相等")
|
||||||
|
}
|
||||||
146
internal/helpers/tlv.go
Normal file
146
internal/helpers/tlv.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TLVReader 提供 TLV(Type-Length-Value)格式的顺序读取能力。
|
||||||
|
// 支持无需反序列化全部报文即可读取特定字段。
|
||||||
|
type TLVReader struct {
|
||||||
|
r io.Reader
|
||||||
|
br io.ByteReader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLVReader 创建新的 TLVReader。
|
||||||
|
func NewTLVReader(r io.Reader) *TLVReader {
|
||||||
|
return &TLVReader{
|
||||||
|
r: r,
|
||||||
|
br: newByteReader(r),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadField 读取下一个 TLV 字段。
|
||||||
|
// 返回字段的长度和值。
|
||||||
|
func (tr *TLVReader) ReadField() ([]byte, error) {
|
||||||
|
length, err := readVarint(tr.br)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read field length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if length == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value := make([]byte, length)
|
||||||
|
if _, errRead := io.ReadFull(tr.r, value); errRead != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read field value: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadStringField 读取下一个 TLV 字段并转换为字符串。
|
||||||
|
func (tr *TLVReader) ReadStringField() (string, error) {
|
||||||
|
data, err := tr.ReadField()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLVWriter 提供 TLV 格式的顺序写入能力。
|
||||||
|
type TLVWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLVWriter 创建新的 TLVWriter。
|
||||||
|
func NewTLVWriter(w io.Writer) *TLVWriter {
|
||||||
|
return &TLVWriter{w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteField 写入一个 TLV 字段。
|
||||||
|
func (tw *TLVWriter) WriteField(value []byte) error {
|
||||||
|
if err := writeVarint(tw.w, uint64(len(value))); err != nil {
|
||||||
|
return fmt.Errorf("failed to write field length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(value) > 0 {
|
||||||
|
if _, err := tw.w.Write(value); err != nil {
|
||||||
|
return fmt.Errorf("failed to write field value: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteStringField 写入一个字符串 TLV 字段。
|
||||||
|
func (tw *TLVWriter) WriteStringField(value string) error {
|
||||||
|
return tw.WriteField([]byte(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Varint 编码/解码函数
|
||||||
|
|
||||||
|
const (
|
||||||
|
// varintContinueBit 表示 varint 还有后续字节的标志位。
|
||||||
|
varintContinueBit = 0x80
|
||||||
|
// varintDataMask 用于提取 varint 数据位的掩码。
|
||||||
|
varintDataMask = 0x7f
|
||||||
|
// varintMaxShift 表示 varint 最大的位移量,防止溢出。
|
||||||
|
varintMaxShift = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeVarint 写入变长整数(类似 Protobuf 的 varint 编码)。
|
||||||
|
// 将 uint64 编码为变长格式,节省存储空间。
|
||||||
|
//
|
||||||
|
|
||||||
|
func writeVarint(w io.Writer, x uint64) error {
|
||||||
|
var buf [10]byte
|
||||||
|
n := 0
|
||||||
|
for x >= varintContinueBit {
|
||||||
|
buf[n] = byte(x) | varintContinueBit
|
||||||
|
x >>= 7
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
buf[n] = byte(x)
|
||||||
|
_, err := w.Write(buf[:n+1])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// readVarint 读取变长整数。
|
||||||
|
// 从字节流中解码 varint 格式的整数。
|
||||||
|
func readVarint(r io.ByteReader) (uint64, error) {
|
||||||
|
var x uint64
|
||||||
|
var shift uint
|
||||||
|
for {
|
||||||
|
b, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
x |= uint64(b&varintDataMask) << shift
|
||||||
|
if b&varintContinueBit == 0 {
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
shift += 7
|
||||||
|
if shift >= varintMaxShift {
|
||||||
|
return 0, errors.New("varint overflow")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// byteReader 为 io.Reader 实现 io.ByteReader 接口。
|
||||||
|
// 提供逐字节读取能力,用于 varint 解码。
|
||||||
|
type byteReader struct {
|
||||||
|
r io.Reader
|
||||||
|
b [1]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newByteReader(r io.Reader) io.ByteReader {
|
||||||
|
return &byteReader{r: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *byteReader) ReadByte() (byte, error) {
|
||||||
|
_, err := br.r.Read(br.b[:])
|
||||||
|
return br.b[0], err
|
||||||
|
}
|
||||||
267
internal/helpers/tlv_test.go
Normal file
267
internal/helpers/tlv_test.go
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
package helpers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTLVReader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r := bytes.NewReader([]byte{})
|
||||||
|
reader := helpers.NewTLVReader(r)
|
||||||
|
assert.NotNil(t, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTLVWriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
assert.NotNil(t, writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLVWriter_WriteField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal field",
|
||||||
|
value: []byte("test"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty field",
|
||||||
|
value: []byte{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large field",
|
||||||
|
value: make([]byte, 1000),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
err := writer.WriteField(tt.value)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
if len(tt.value) > 0 {
|
||||||
|
assert.Positive(t, buf.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLVWriter_WriteStringField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
|
||||||
|
err := writer.WriteStringField("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Positive(t, buf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLVReader_ReadField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func() *helpers.TLVReader
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal field",
|
||||||
|
setup: func() *helpers.TLVReader {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
_ = writer.WriteField([]byte("test"))
|
||||||
|
return helpers.NewTLVReader(&buf)
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty field",
|
||||||
|
setup: func() *helpers.TLVReader {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
_ = writer.WriteField([]byte{})
|
||||||
|
return helpers.NewTLVReader(&buf)
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid data",
|
||||||
|
setup: func() *helpers.TLVReader {
|
||||||
|
return helpers.NewTLVReader(bytes.NewReader([]byte{0xff, 0xff}))
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty reader",
|
||||||
|
setup: func() *helpers.TLVReader {
|
||||||
|
return helpers.NewTLVReader(bytes.NewReader([]byte{}))
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
reader := tt.setup()
|
||||||
|
result, err := reader.ReadField()
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Empty field returns nil
|
||||||
|
if tt.name == "empty field" {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLVReader_ReadStringField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
err := writer.WriteStringField("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
reader := helpers.NewTLVReader(&buf)
|
||||||
|
result, err := reader.ReadStringField()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "test", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLV_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal",
|
||||||
|
value: []byte("test"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
value: []byte{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large",
|
||||||
|
value: make([]byte, 100),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Write
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
err := writer.WriteField(tt.value)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Read
|
||||||
|
reader := helpers.NewTLVReader(&buf)
|
||||||
|
result, err := reader.ReadField()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
// Empty byte slice returns nil from ReadField
|
||||||
|
if len(tt.value) == 0 {
|
||||||
|
assert.Nil(t, result)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.value, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLV_MultipleFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
|
||||||
|
// Write multiple fields
|
||||||
|
fields := [][]byte{
|
||||||
|
[]byte("field1"),
|
||||||
|
[]byte("field2"),
|
||||||
|
[]byte("field3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range fields {
|
||||||
|
err := writer.WriteField(field)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read multiple fields
|
||||||
|
reader := helpers.NewTLVReader(&buf)
|
||||||
|
for i, expected := range fields {
|
||||||
|
result, err := reader.ReadField()
|
||||||
|
require.NoError(t, err, "field %d", i)
|
||||||
|
assert.Equal(t, expected, result, "field %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLV_StringRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
|
||||||
|
original := "test string"
|
||||||
|
err := writer.WriteStringField(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
reader := helpers.NewTLVReader(&buf)
|
||||||
|
result, err := reader.ReadStringField()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, original, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLVReader_ReadField_EOF(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := helpers.NewTLVWriter(&buf)
|
||||||
|
_ = writer.WriteField([]byte("test"))
|
||||||
|
|
||||||
|
reader := helpers.NewTLVReader(&buf)
|
||||||
|
_, err := reader.ReadField()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Try to read beyond EOF - this will fail when trying to read varint length
|
||||||
|
_, err = reader.ReadField()
|
||||||
|
require.Error(t, err)
|
||||||
|
// Error could be EOF or other read error
|
||||||
|
assert.Contains(t, err.Error(), "failed to read")
|
||||||
|
}
|
||||||
66
internal/helpers/uuid.go
Normal file
66
internal/helpers/uuid.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UUID v7 格式常量.
|
||||||
|
uuidRandomBytesSize = 10 // UUID中随机字节部分的大小
|
||||||
|
uuidVersion7 = 0x7000 // UUID v7的版本位
|
||||||
|
uuidVariant = 0x80 // UUID的变体位
|
||||||
|
uuidTimeMask = 0xFFFF // 时间戳掩码
|
||||||
|
uuidTimeShift = 16 // 时间戳位移
|
||||||
|
uuidVariantMask = 0x3F // 变体掩码
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewUUIDv7 生成 UUID v7 并去除连字符.
|
||||||
|
func NewUUIDv7() string {
|
||||||
|
// 获取当前时间戳(Unix 毫秒时间戳)
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
// 生成随机字节
|
||||||
|
randBytes := make([]byte, uuidRandomBytesSize)
|
||||||
|
_, err := rand.Read(randBytes)
|
||||||
|
if err != nil {
|
||||||
|
// 如果随机数生成失败,使用时间戳加一些伪随机值作为备选方案
|
||||||
|
return fmt.Sprintf("%016x%016x", now, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 版本和变体位
|
||||||
|
// 版本: 0x7 (0111) << 12
|
||||||
|
// 变体: 0x2 (10) << 6
|
||||||
|
versionVariant := uint16(uuidVersion7 | uuidVariant)
|
||||||
|
|
||||||
|
// 构建 UUID 字节数组
|
||||||
|
var uuid [16]byte
|
||||||
|
|
||||||
|
// 时间戳低32位 (4 bytes)
|
||||||
|
//nolint:gosec // UUID格式要求的类型转换
|
||||||
|
binary.BigEndian.PutUint32(uuid[0:4], uint32(now>>uuidTimeShift))
|
||||||
|
|
||||||
|
// 时间戳中16位 + 版本 (2 bytes)
|
||||||
|
//nolint:gosec // UUID格式要求的类型转换
|
||||||
|
binary.BigEndian.PutUint16(uuid[4:6], uint16(now&uuidTimeMask))
|
||||||
|
|
||||||
|
// 时间戳高16位 + 变体 (2 bytes)
|
||||||
|
binary.BigEndian.PutUint16(uuid[6:8], versionVariant)
|
||||||
|
|
||||||
|
// 随机数部分 (8 bytes)
|
||||||
|
copy(uuid[8:16], randBytes[:8])
|
||||||
|
|
||||||
|
// 设置变体位 (第8个字节的高两位为10)
|
||||||
|
uuid[8] = (uuid[8] & uuidVariantMask) | uuidVariant
|
||||||
|
|
||||||
|
// 转换为十六进制字符串并去除连字符
|
||||||
|
return fmt.Sprintf("%08x%04x%04x%04x%08x%04x",
|
||||||
|
binary.BigEndian.Uint32(uuid[0:4]),
|
||||||
|
binary.BigEndian.Uint16(uuid[4:6]),
|
||||||
|
binary.BigEndian.Uint16(uuid[6:8]),
|
||||||
|
binary.BigEndian.Uint16(uuid[8:10]),
|
||||||
|
binary.BigEndian.Uint32(uuid[10:14]),
|
||||||
|
binary.BigEndian.Uint16(uuid[14:16]))
|
||||||
|
}
|
||||||
151
internal/helpers/uuid_test.go
Normal file
151
internal/helpers/uuid_test.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package helpers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUUIDv7(t *testing.T) {
|
||||||
|
// UUID v7 格式:无连字符,32个十六进制字符
|
||||||
|
uuidPattern := regexp.MustCompile(`^[0-9a-f]{32}$`)
|
||||||
|
|
||||||
|
t.Run("生成有效的UUID", func(t *testing.T) {
|
||||||
|
uuid := helpers.NewUUIDv7()
|
||||||
|
|
||||||
|
// 验证格式
|
||||||
|
assert.Len(t, uuid, 32, "UUID长度应该是32个字符")
|
||||||
|
assert.Regexp(t, uuidPattern, uuid, "UUID应该只包含小写十六进制字符")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("每次生成的UUID应该不同", func(t *testing.T) {
|
||||||
|
uuid1 := helpers.NewUUIDv7()
|
||||||
|
uuid2 := helpers.NewUUIDv7()
|
||||||
|
uuid3 := helpers.NewUUIDv7()
|
||||||
|
|
||||||
|
assert.NotEqual(t, uuid1, uuid2)
|
||||||
|
assert.NotEqual(t, uuid2, uuid3)
|
||||||
|
assert.NotEqual(t, uuid1, uuid3)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UUID格式验证", func(t *testing.T) {
|
||||||
|
uuid := helpers.NewUUIDv7()
|
||||||
|
|
||||||
|
// UUID v7 应该是 32 个十六进制字符
|
||||||
|
require.Len(t, uuid, 32)
|
||||||
|
|
||||||
|
// 检查每个字符都是有效的十六进制
|
||||||
|
for i, c := range uuid {
|
||||||
|
assert.True(t,
|
||||||
|
(c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
|
||||||
|
"字符 %c 在位置 %d 不是有效的十六进制字符", c, i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("并发生成UUID", func(t *testing.T) {
|
||||||
|
const concurrency = 100
|
||||||
|
uuids := make([]string, concurrency)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(concurrency)
|
||||||
|
|
||||||
|
for i := range concurrency {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
uuids[idx] = helpers.NewUUIDv7()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 验证所有 UUID 都不为空且格式正确
|
||||||
|
for i, uuid := range uuids {
|
||||||
|
assert.NotEmpty(t, uuid, "UUID %d 不应该为空", i)
|
||||||
|
assert.Regexp(t, uuidPattern, uuid, "UUID %d 格式不正确", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有 UUID 都是唯一的
|
||||||
|
uniqueMap := make(map[string]bool)
|
||||||
|
for _, uuid := range uuids {
|
||||||
|
assert.False(t, uniqueMap[uuid], "UUID重复: %s", uuid)
|
||||||
|
uniqueMap[uuid] = true
|
||||||
|
}
|
||||||
|
assert.Len(t, uniqueMap, concurrency, "应该生成%d个唯一的UUID", concurrency)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UUID包含时间戳信息", func(t *testing.T) {
|
||||||
|
// 连续生成多个UUID,它们的时间戳部分应该相近或递增
|
||||||
|
uuid1 := helpers.NewUUIDv7()
|
||||||
|
uuid2 := helpers.NewUUIDv7()
|
||||||
|
|
||||||
|
// UUID v7 的前12个字符主要是时间戳
|
||||||
|
// 在很短的时间内生成的UUID,时间戳部分应该相同或非常接近
|
||||||
|
timePrefix1 := uuid1[:12]
|
||||||
|
timePrefix2 := uuid2[:12]
|
||||||
|
|
||||||
|
// 时间戳应该相同或第二个略大(因为时间在递增)
|
||||||
|
assert.True(t,
|
||||||
|
timePrefix1 == timePrefix2 || timePrefix1 <= timePrefix2,
|
||||||
|
"UUID的时间戳部分应该单调递增")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("批量生成UUID性能测试", func(t *testing.T) {
|
||||||
|
const iterations = 1000
|
||||||
|
uuids := make([]string, iterations)
|
||||||
|
|
||||||
|
for i := range iterations {
|
||||||
|
uuids[i] = helpers.NewUUIDv7()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有UUID都有效
|
||||||
|
for i, uuid := range uuids {
|
||||||
|
assert.Regexp(t, uuidPattern, uuid, "UUID %d 格式不正确", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单的唯一性检查
|
||||||
|
uniqueMap := make(map[string]bool)
|
||||||
|
for _, uuid := range uuids {
|
||||||
|
uniqueMap[uuid] = true
|
||||||
|
}
|
||||||
|
assert.Len(t, uniqueMap, iterations, "应该生成%d个唯一的UUID", iterations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUUIDv7_Format(t *testing.T) {
|
||||||
|
// 测试UUID v7的具体格式要求
|
||||||
|
uuid := helpers.NewUUIDv7()
|
||||||
|
|
||||||
|
// 总长度 32
|
||||||
|
assert.Len(t, uuid, 32)
|
||||||
|
|
||||||
|
// 全部小写
|
||||||
|
for _, c := range uuid {
|
||||||
|
if c >= 'a' && c <= 'f' {
|
||||||
|
assert.True(t, c >= 'a' && c <= 'f')
|
||||||
|
} else {
|
||||||
|
assert.True(t, c >= '0' && c <= '9')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUUIDv7_EdgeCases(t *testing.T) {
|
||||||
|
t.Run("快速连续生成", func(t *testing.T) {
|
||||||
|
// 在极短时间内生成多个UUID
|
||||||
|
uuids := make([]string, 10)
|
||||||
|
for i := range 10 {
|
||||||
|
uuids[i] = helpers.NewUUIDv7()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有UUID应该都有效且唯一
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, uuid := range uuids {
|
||||||
|
assert.Len(t, uuid, 32)
|
||||||
|
assert.False(t, seen[uuid], "UUID不应该重复")
|
||||||
|
seen[uuid] = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
20
internal/helpers/validate.go
Normal file
20
internal/helpers/validate.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // 单例模式需要全局变量
|
||||||
|
var (
|
||||||
|
validate *validator.Validate
|
||||||
|
once sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetValidator() *validator.Validate {
|
||||||
|
once.Do(func() {
|
||||||
|
validate = validator.New()
|
||||||
|
})
|
||||||
|
return validate
|
||||||
|
}
|
||||||
186
internal/helpers/validate_test.go
Normal file
186
internal/helpers/validate_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package helpers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetValidator(t *testing.T) {
|
||||||
|
t.Run("返回有效的validator实例", func(t *testing.T) {
|
||||||
|
v := helpers.GetValidator()
|
||||||
|
|
||||||
|
require.NotNil(t, v, "Validator不应该为nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("单例模式:多次调用返回同一个实例", func(t *testing.T) {
|
||||||
|
v1 := helpers.GetValidator()
|
||||||
|
v2 := helpers.GetValidator()
|
||||||
|
v3 := helpers.GetValidator()
|
||||||
|
|
||||||
|
// 使用指针比较,确保是同一个实例
|
||||||
|
assert.Same(t, v1, v2, "第一次和第二次调用应该返回同一个实例")
|
||||||
|
assert.Same(t, v2, v3, "第二次和第三次调用应该返回同一个实例")
|
||||||
|
assert.Same(t, v1, v3, "第一次和第三次调用应该返回同一个实例")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("并发获取validator应该安全", func(t *testing.T) {
|
||||||
|
const concurrency = 100
|
||||||
|
validators := make([]interface{}, concurrency)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(concurrency)
|
||||||
|
|
||||||
|
for i := range concurrency {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
v := helpers.GetValidator()
|
||||||
|
// 存储validator实例
|
||||||
|
validators[idx] = v
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 验证所有goroutine获取的是同一个实例
|
||||||
|
firstValidator := validators[0]
|
||||||
|
for i := 1; i < concurrency; i++ {
|
||||||
|
assert.Same(t, firstValidator, validators[i],
|
||||||
|
"并发调用第%d次获取的validator应该与第一次相同", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("validator可以正常工作", func(t *testing.T) {
|
||||||
|
v := helpers.GetValidator()
|
||||||
|
|
||||||
|
// 测试一个简单的结构体验证
|
||||||
|
type TestStruct struct {
|
||||||
|
Name string `validate:"required,min=2,max=10"`
|
||||||
|
Email string `validate:"required,email"`
|
||||||
|
Age int `validate:"gte=0,lte=120"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 有效的结构体
|
||||||
|
validData := TestStruct{
|
||||||
|
Name: "John",
|
||||||
|
Email: "john@example.com",
|
||||||
|
Age: 30,
|
||||||
|
}
|
||||||
|
err := v.Struct(validData)
|
||||||
|
require.NoError(t, err, "有效的数据不应该产生验证错误")
|
||||||
|
|
||||||
|
// 无效的结构体 - 缺少必填字段
|
||||||
|
invalidData1 := TestStruct{
|
||||||
|
Name: "",
|
||||||
|
Age: 30,
|
||||||
|
}
|
||||||
|
err = v.Struct(invalidData1)
|
||||||
|
require.Error(t, err, "缺少必填字段应该产生验证错误")
|
||||||
|
|
||||||
|
// 无效的结构体 - 字段值超出范围
|
||||||
|
invalidData2 := TestStruct{
|
||||||
|
Name: "John",
|
||||||
|
Email: "john@example.com",
|
||||||
|
Age: 150,
|
||||||
|
}
|
||||||
|
err = v.Struct(invalidData2)
|
||||||
|
require.Error(t, err, "年龄超出范围应该产生验证错误")
|
||||||
|
|
||||||
|
// 无效的结构体 - 邮箱格式错误
|
||||||
|
invalidData3 := TestStruct{
|
||||||
|
Name: "John",
|
||||||
|
Email: "invalid-email",
|
||||||
|
Age: 30,
|
||||||
|
}
|
||||||
|
err = v.Struct(invalidData3)
|
||||||
|
assert.Error(t, err, "无效的邮箱格式应该产生验证错误")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidator_InitializationOnce(t *testing.T) {
|
||||||
|
// 这个测试验证 sync.Once 确保初始化只执行一次
|
||||||
|
const calls = 1000
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(calls)
|
||||||
|
|
||||||
|
results := make([]interface{}, calls)
|
||||||
|
|
||||||
|
for i := range calls {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
v := helpers.GetValidator()
|
||||||
|
results[idx] = v
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 所有结果应该指向同一个实例
|
||||||
|
first := results[0]
|
||||||
|
for i := 1; i < calls; i++ {
|
||||||
|
assert.Same(t, first, results[i],
|
||||||
|
"所有调用应该返回完全相同的validator实例")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidator_ValidatorFunctionality(t *testing.T) {
|
||||||
|
v := helpers.GetValidator()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "结构体字段验证-成功",
|
||||||
|
data: struct {
|
||||||
|
Field string `validate:"required"`
|
||||||
|
}{
|
||||||
|
Field: "value",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "结构体字段验证-失败",
|
||||||
|
data: struct {
|
||||||
|
Field string `validate:"required"`
|
||||||
|
}{
|
||||||
|
Field: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "数字范围验证-成功",
|
||||||
|
data: struct {
|
||||||
|
Count int `validate:"min=1,max=100"`
|
||||||
|
}{
|
||||||
|
Count: 50,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "数字范围验证-失败",
|
||||||
|
data: struct {
|
||||||
|
Count int `validate:"min=1,max=100"`
|
||||||
|
}{
|
||||||
|
Count: 200,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := v.Struct(tt.data)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
194
internal/logger/logger.go
Normal file
194
internal/logger/logger.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ThreeDotsLabs/watermill"
|
||||||
|
"github.com/apache/pulsar-client-go/pulsar/log"
|
||||||
|
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// argsPerField 每个字段转换为args时的参数数量(key+value).
|
||||||
|
argsPerField = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type WatermillLoggerAdapter struct {
|
||||||
|
logger logger.Logger
|
||||||
|
fields watermill.LogFields
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WatermillLoggerAdapter) Error(msg string, err error, fields watermill.LogFields) {
|
||||||
|
allFields := mergeFields(w.fields, fields)
|
||||||
|
args := allFieldsToArgs(allFields)
|
||||||
|
w.logger.Error(fmt.Sprintf("%s: %v", msg, err), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WatermillLoggerAdapter) Info(msg string, fields watermill.LogFields) {
|
||||||
|
allFields := mergeFields(w.fields, fields)
|
||||||
|
args := allFieldsToArgs(allFields)
|
||||||
|
w.logger.Info(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WatermillLoggerAdapter) Debug(msg string, fields watermill.LogFields) {
|
||||||
|
allFields := mergeFields(w.fields, fields)
|
||||||
|
args := allFieldsToArgs(allFields)
|
||||||
|
w.logger.Debug(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WatermillLoggerAdapter) Trace(msg string, fields watermill.LogFields) {
|
||||||
|
allFields := mergeFields(w.fields, fields)
|
||||||
|
args := allFieldsToArgs(allFields)
|
||||||
|
w.logger.Debug(fmt.Sprintf("[TRACE] %s", msg), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WatermillLoggerAdapter) With(fields watermill.LogFields) watermill.LoggerAdapter {
|
||||||
|
newFields := mergeFields(w.fields, fields)
|
||||||
|
return WatermillLoggerAdapter{
|
||||||
|
logger: w.logger,
|
||||||
|
fields: newFields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:funcorder // 构造函数放在此处更符合代码组织
|
||||||
|
func NewWatermillLoggerAdapter(logger logger.Logger) *WatermillLoggerAdapter {
|
||||||
|
return &WatermillLoggerAdapter{logger: logger, fields: watermill.LogFields{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeFields(base, extra watermill.LogFields) watermill.LogFields {
|
||||||
|
merged := make(watermill.LogFields, len(base)+len(extra))
|
||||||
|
for k, v := range base {
|
||||||
|
merged[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range extra {
|
||||||
|
merged[k] = v
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func allFieldsToArgs(fields watermill.LogFields) []any {
|
||||||
|
args := make([]any, 0, len(fields)*argsPerField)
|
||||||
|
for k, v := range fields {
|
||||||
|
args = append(args, k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// ================= PulsarLoggerAdapter ======================
|
||||||
|
|
||||||
|
type PulsarLoggerAdapter struct {
|
||||||
|
logger logger.Logger
|
||||||
|
fields log.Fields
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPulsarLoggerAdapter(l logger.Logger) *PulsarLoggerAdapter {
|
||||||
|
return &PulsarLoggerAdapter{logger: l, fields: log.Fields{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) SubLogger(fields log.Fields) log.Logger {
|
||||||
|
return PulsarLoggerAdapter{
|
||||||
|
logger: p.logger,
|
||||||
|
fields: mergePulsarFields(p.fields, fields),
|
||||||
|
err: p.err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) WithFields(fields log.Fields) log.Entry {
|
||||||
|
return PulsarLoggerAdapter{
|
||||||
|
logger: p.logger,
|
||||||
|
fields: mergePulsarFields(p.fields, fields),
|
||||||
|
err: p.err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) WithField(name string, value interface{}) log.Entry {
|
||||||
|
newFields := mergePulsarFields(p.fields, log.Fields{name: value})
|
||||||
|
return PulsarLoggerAdapter{
|
||||||
|
logger: p.logger,
|
||||||
|
fields: newFields,
|
||||||
|
err: p.err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) WithError(err error) log.Entry {
|
||||||
|
return PulsarLoggerAdapter{
|
||||||
|
logger: p.logger,
|
||||||
|
fields: p.fields,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Debug(args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Debug(fmt.Sprint(args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Info(args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Info(fmt.Sprint(args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Warn(args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Warn(fmt.Sprint(args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Error(args ...interface{}) {
|
||||||
|
msg := fmt.Sprint(args...)
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
if p.err != nil {
|
||||||
|
// 将error作为key-value对添加到args中
|
||||||
|
fieldsArgs = append(fieldsArgs, "error", p.err)
|
||||||
|
p.logger.Error(msg, fieldsArgs...)
|
||||||
|
} else {
|
||||||
|
p.logger.Error(msg, fieldsArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Debugf(format string, args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Debug(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Infof(format string, args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Info(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Warnf(format string, args ...interface{}) {
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
p.logger.Warn(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PulsarLoggerAdapter) Errorf(format string, args ...interface{}) {
|
||||||
|
msg := fmt.Sprintf(format, args...)
|
||||||
|
fieldsArgs := fieldsToArgs(p.fields)
|
||||||
|
if p.err != nil {
|
||||||
|
p.logger.Error(fmt.Sprintf("%s: %v", msg, p.err), fieldsArgs...)
|
||||||
|
} else {
|
||||||
|
p.logger.Error(msg, fieldsArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并 Pulsar log.Fields.
|
||||||
|
func mergePulsarFields(base, extra log.Fields) log.Fields {
|
||||||
|
merged := make(log.Fields, len(base)+len(extra))
|
||||||
|
for k, v := range base {
|
||||||
|
merged[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range extra {
|
||||||
|
merged[k] = v
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 Pulsar log.Fields 转为 args ...any 形式,适配 Adapter.
|
||||||
|
func fieldsToArgs(fields log.Fields) []any {
|
||||||
|
args := make([]any, 0, len(fields)*argsPerField)
|
||||||
|
for k, v := range fields {
|
||||||
|
args = append(args, k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
385
internal/logger/logger_test.go
Normal file
385
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
package logger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
apilogger "go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||||
|
"go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewWatermillLoggerAdapter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
assert.NotNil(t, adapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_Error(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
err := errors.New("test error")
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Error("error message", err, fields)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_Info(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Info("info message", fields)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_Debug(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Debug("debug message", fields)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_Trace(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Trace("trace message", fields)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_With(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields1 := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
}
|
||||||
|
fields2 := map[string]interface{}{
|
||||||
|
"key2": "value2",
|
||||||
|
}
|
||||||
|
|
||||||
|
newAdapter := adapter.With(fields1)
|
||||||
|
assert.NotNil(t, newAdapter)
|
||||||
|
|
||||||
|
// Test that fields are merged
|
||||||
|
newAdapter2 := newAdapter.With(fields2)
|
||||||
|
assert.NotNil(t, newAdapter2)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
newAdapter2.Info("test", map[string]interface{}{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPulsarLoggerAdapter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
assert.NotNil(t, adapter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Debug(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Debug("debug message")
|
||||||
|
adapter.Debug("debug", "message", "with", "args")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Info(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Info("info message")
|
||||||
|
adapter.Info("info", "message", "with", "args")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Warn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Warn("warn message")
|
||||||
|
adapter.Warn("warn", "message", "with", "args")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Error(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
test func()
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without error",
|
||||||
|
test: func() {
|
||||||
|
adapter.Error("error message")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with error",
|
||||||
|
test: func() {
|
||||||
|
adapterWithErr := adapter.WithError(errors.New("test error"))
|
||||||
|
adapterWithErr.Error("error message")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.NotPanics(t, tt.test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Debugf(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Debugf("debug %s", "message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Infof(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Infof("info %s", "message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Warnf(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Warnf("warn %s", "message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_Errorf(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
test func()
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without error",
|
||||||
|
test: func() {
|
||||||
|
adapter.Errorf("error %s", "message")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with error",
|
||||||
|
test: func() {
|
||||||
|
adapterWithErr := adapter.WithError(errors.New("test error"))
|
||||||
|
adapterWithErr.Errorf("error %s", "message")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.NotPanics(t, tt.test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_SubLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
}
|
||||||
|
|
||||||
|
subLogger := adapter.SubLogger(fields)
|
||||||
|
assert.NotNil(t, subLogger)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
subLogger.Info("test")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_WithFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := adapter.WithFields(fields)
|
||||||
|
assert.NotNil(t, entry)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
entry.Info("test")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_WithField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
entry := adapter.WithField("key", "value")
|
||||||
|
assert.NotNil(t, entry)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
entry.Info("test")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_WithError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
err := errors.New("test error")
|
||||||
|
entry := adapter.WithError(err)
|
||||||
|
assert.NotNil(t, entry)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
entry.Error("test error message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_ChainedFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
entry1 := adapter.WithField("key1", "value1")
|
||||||
|
entry2 := entry1.WithField("key2", "value2")
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
entry2.Info("chained fields test")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test WithError separately
|
||||||
|
entryWithErr := adapter.WithError(errors.New("test error"))
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
entryWithErr.Error("chained fields test")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulsarLoggerAdapter_FormatMethods(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Debugf("debug %d", 1)
|
||||||
|
adapter.Infof("info %d", 2)
|
||||||
|
adapter.Warnf("warn %d", 3)
|
||||||
|
adapter.Errorf("error %d", 4)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_EmptyFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapter.Error("error", errors.New("test"), map[string]interface{}{})
|
||||||
|
adapter.Info("info", map[string]interface{}{})
|
||||||
|
adapter.Debug("debug", map[string]interface{}{})
|
||||||
|
adapter.Trace("trace", map[string]interface{}{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatermillLoggerAdapter_MergedFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := apilogger.NewNopLogger()
|
||||||
|
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||||
|
|
||||||
|
baseFields := map[string]interface{}{
|
||||||
|
"base": "value",
|
||||||
|
}
|
||||||
|
extraFields := map[string]interface{}{
|
||||||
|
"extra": "value",
|
||||||
|
}
|
||||||
|
|
||||||
|
adapterWithBase := adapter.With(baseFields)
|
||||||
|
adapterWithBoth := adapterWithBase.With(extraFields)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
adapterWithBoth.Info("test", map[string]interface{}{})
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user