refactor: 重构trustlog-sdk目录结构到trustlog/go-trustlog
- 将所有trustlog-sdk文件移动到trustlog/go-trustlog/目录 - 更新README中所有import路径从trustlog-sdk改为go-trustlog - 更新cookiecutter配置文件中的项目名称 - 更新根目录.lefthook.yml以引用新位置的配置 - 添加go.sum文件到版本控制 - 删除过时的示例文件 这次重构与trustlog-server保持一致的目录结构, 为未来支持多语言SDK(Python、Java等)预留空间。
This commit is contained in:
30
internal/grpcclient/config.go
Normal file
30
internal/grpcclient/config.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Config 客户端配置.
|
||||
type Config struct {
|
||||
// ServerAddrs gRPC服务器地址列表,格式: "host:port"
|
||||
// 支持多个地址,客户端将使用轮询负载均衡
|
||||
ServerAddrs []string
|
||||
// ServerAddr 单个服务器地址(向后兼容),如果设置了此字段,将忽略ServerAddrs
|
||||
ServerAddr string
|
||||
// DialOptions 额外的gRPC拨号选项
|
||||
DialOptions []grpc.DialOption
|
||||
}
|
||||
|
||||
// GetAddrs 获取服务器地址列表.
|
||||
func (c *Config) GetAddrs() ([]string, error) {
|
||||
switch {
|
||||
case len(c.ServerAddrs) > 0:
|
||||
return c.ServerAddrs, nil
|
||||
case c.ServerAddr != "":
|
||||
return []string{c.ServerAddr}, nil
|
||||
default:
|
||||
return nil, errors.New("at least one server address is required")
|
||||
}
|
||||
}
|
||||
119
internal/grpcclient/config_test.go
Normal file
119
internal/grpcclient/config_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package grpcclient_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/grpcclient"
|
||||
)
|
||||
|
||||
func TestConfig_GetAddrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config grpcclient.Config
|
||||
wantAddrs []string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "ServerAddrs优先级高于ServerAddr",
|
||||
config: grpcclient.Config{
|
||||
ServerAddrs: []string{"server1:9090", "server2:9090"},
|
||||
ServerAddr: "server3:9090",
|
||||
},
|
||||
wantAddrs: []string{"server1:9090", "server2:9090"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "只有ServerAddrs",
|
||||
config: grpcclient.Config{
|
||||
ServerAddrs: []string{"server1:9090", "server2:9090", "server3:9090"},
|
||||
},
|
||||
wantAddrs: []string{"server1:9090", "server2:9090", "server3:9090"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "只有ServerAddr",
|
||||
config: grpcclient.Config{
|
||||
ServerAddr: "server1:9090",
|
||||
},
|
||||
wantAddrs: []string{"server1:9090"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ServerAddrs为空,使用ServerAddr",
|
||||
config: grpcclient.Config{
|
||||
ServerAddrs: []string{},
|
||||
ServerAddr: "server1:9090",
|
||||
},
|
||||
wantAddrs: []string{"server1:9090"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "没有任何地址应该返回错误",
|
||||
config: grpcclient.Config{},
|
||||
wantAddrs: nil,
|
||||
wantErr: true,
|
||||
errMsg: "at least one server address is required",
|
||||
},
|
||||
{
|
||||
name: "ServerAddrs为空且ServerAddr为空",
|
||||
config: grpcclient.Config{
|
||||
ServerAddrs: []string{},
|
||||
ServerAddr: "",
|
||||
},
|
||||
wantAddrs: nil,
|
||||
wantErr: true,
|
||||
errMsg: "at least one server address is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addrs, err := tt.config.GetAddrs()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
assert.Nil(t, addrs)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantAddrs, addrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_EmptyServerAddrs(t *testing.T) {
|
||||
// 测试空的 ServerAddrs 切片
|
||||
config := grpcclient.Config{
|
||||
ServerAddrs: []string{},
|
||||
ServerAddr: "fallback:9090",
|
||||
}
|
||||
|
||||
addrs, err := config.GetAddrs()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"fallback:9090"}, addrs)
|
||||
}
|
||||
|
||||
func TestConfig_MultipleServerAddrs(t *testing.T) {
|
||||
// 测试多个服务器地址
|
||||
config := grpcclient.Config{
|
||||
ServerAddrs: []string{
|
||||
"server1:9090",
|
||||
"server2:9091",
|
||||
"server3:9092",
|
||||
"server4:9093",
|
||||
},
|
||||
}
|
||||
|
||||
addrs, err := config.GetAddrs()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, addrs, 4)
|
||||
assert.Equal(t, "server1:9090", addrs[0])
|
||||
assert.Equal(t, "server4:9093", addrs[3])
|
||||
}
|
||||
113
internal/grpcclient/loadbalancer.go
Normal file
113
internal/grpcclient/loadbalancer.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// ClientFactory 客户端工厂函数类型.
|
||||
type ClientFactory[T any] func(grpc.ClientConnInterface) T
|
||||
|
||||
// ServerClient 封装单个服务器的连接.
|
||||
type ServerClient[T any] struct {
|
||||
addr string
|
||||
conn *grpc.ClientConn
|
||||
client T
|
||||
}
|
||||
|
||||
// LoadBalancer 轮询负载均衡器(泛型版本).
|
||||
type LoadBalancer[T any] struct {
|
||||
servers []*ServerClient[T]
|
||||
counter atomic.Uint64
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewLoadBalancer 创建新的负载均衡器.
|
||||
func NewLoadBalancer[T any](
|
||||
addrs []string,
|
||||
dialOpts []grpc.DialOption,
|
||||
factory ClientFactory[T],
|
||||
) (*LoadBalancer[T], error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("at least one server address is required")
|
||||
}
|
||||
|
||||
lb := &LoadBalancer[T]{
|
||||
servers: make([]*ServerClient[T], 0, len(addrs)),
|
||||
}
|
||||
|
||||
// 默认使用不安全的连接(生产环境应使用TLS)
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
}
|
||||
opts = append(opts, dialOpts...)
|
||||
|
||||
// 连接所有服务器
|
||||
for _, addr := range addrs {
|
||||
conn, err := grpc.NewClient(addr, opts...)
|
||||
if err != nil {
|
||||
// 关闭已创建的连接
|
||||
_ = lb.Close()
|
||||
return nil, fmt.Errorf("failed to connect to server %s: %w", addr, err)
|
||||
}
|
||||
|
||||
client := factory(conn)
|
||||
lb.servers = append(lb.servers, &ServerClient[T]{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
client: client,
|
||||
})
|
||||
}
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Next 使用轮询算法获取下一个客户端.
|
||||
func (lb *LoadBalancer[T]) Next() T {
|
||||
lb.mu.RLock()
|
||||
defer lb.mu.RUnlock()
|
||||
|
||||
if len(lb.servers) == 0 || lb.closed {
|
||||
var zero T
|
||||
return zero
|
||||
}
|
||||
|
||||
// 原子递增计数器并取模
|
||||
idx := lb.counter.Add(1) % uint64(len(lb.servers))
|
||||
return lb.servers[idx].client
|
||||
}
|
||||
|
||||
// Close 关闭所有连接.
|
||||
func (lb *LoadBalancer[T]) Close() error {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
// 如果已经关闭,直接返回
|
||||
if lb.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, server := range lb.servers {
|
||||
if err := server.conn.Close(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// 标记为已关闭
|
||||
lb.closed = true
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// ServerCount 返回服务器数量.
|
||||
func (lb *LoadBalancer[T]) ServerCount() int {
|
||||
lb.mu.RLock()
|
||||
defer lb.mu.RUnlock()
|
||||
return len(lb.servers)
|
||||
}
|
||||
186
internal/grpcclient/loadbalancer_test.go
Normal file
186
internal/grpcclient/loadbalancer_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package grpcclient_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/grpcclient"
|
||||
)
|
||||
|
||||
// mockClient 用于测试的模拟客户端.
|
||||
type mockClient struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func TestNewLoadBalancer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addrs []string
|
||||
dialOpts []grpc.DialOption
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "成功创建负载均衡器",
|
||||
addrs: []string{
|
||||
"localhost:9090",
|
||||
"localhost:9091",
|
||||
},
|
||||
dialOpts: []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "没有地址应该失败",
|
||||
addrs: []string{},
|
||||
dialOpts: nil,
|
||||
wantErr: true,
|
||||
errMsg: "at least one server address is required",
|
||||
},
|
||||
{
|
||||
name: "nil地址列表应该失败",
|
||||
addrs: nil,
|
||||
dialOpts: nil,
|
||||
wantErr: true,
|
||||
errMsg: "at least one server address is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
lb, err := grpcclient.NewLoadBalancer(
|
||||
tt.addrs,
|
||||
tt.dialOpts,
|
||||
func(_ grpc.ClientConnInterface) *mockClient {
|
||||
return &mockClient{ID: "test"}
|
||||
},
|
||||
)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
assert.Nil(t, lb)
|
||||
} else {
|
||||
// 注意:这里会实际尝试连接,在测试环境下可能失败
|
||||
// 实际使用时应该使用 mock 或 bufconn
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test - cannot connect to servers: %v", err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lb)
|
||||
assert.Equal(t, len(tt.addrs), lb.ServerCount())
|
||||
// 清理
|
||||
_ = lb.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBalancer_Next(t *testing.T) {
|
||||
// 创建一个模拟的负载均衡器,不需要真实连接
|
||||
t.Run("轮询算法测试", func(t *testing.T) {
|
||||
// 这个测试需要使用 bufconn 或其他 mock 方式
|
||||
// 暂时跳过需要真实连接的测试
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test that requires network connection")
|
||||
}
|
||||
|
||||
addrs := []string{"localhost:9090", "localhost:9091", "localhost:9092"}
|
||||
lb, err := grpcclient.NewLoadBalancer(
|
||||
addrs,
|
||||
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||
func(_ grpc.ClientConnInterface) *mockClient {
|
||||
return &mockClient{ID: "test"}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create load balancer: %v", err)
|
||||
return
|
||||
}
|
||||
defer lb.Close()
|
||||
|
||||
// 测试轮询:调用 Next() 多次应该轮询返回不同的客户端
|
||||
clients := make([]*mockClient, 6)
|
||||
for i := range 6 {
|
||||
clients[i] = lb.Next()
|
||||
assert.NotNil(t, clients[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadBalancer_Close(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test that requires network connection")
|
||||
}
|
||||
|
||||
addrs := []string{"localhost:9090"}
|
||||
lb, err := grpcclient.NewLoadBalancer(
|
||||
addrs,
|
||||
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||
func(_ grpc.ClientConnInterface) *mockClient {
|
||||
return &mockClient{ID: "test"}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create load balancer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 第一次关闭
|
||||
err = lb.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再次关闭应该也不会报错
|
||||
err = lb.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLoadBalancer_ServerCount(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test that requires network connection")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addrs []string
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "单服务器",
|
||||
addrs: []string{"localhost:9090"},
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "多服务器",
|
||||
addrs: []string{"localhost:9090", "localhost:9091", "localhost:9092"},
|
||||
wantCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
lb, err := grpcclient.NewLoadBalancer(
|
||||
tt.addrs,
|
||||
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
|
||||
func(_ grpc.ClientConnInterface) *mockClient {
|
||||
return &mockClient{ID: "test"}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create load balancer: %v", err)
|
||||
return
|
||||
}
|
||||
defer lb.Close()
|
||||
|
||||
assert.Equal(t, tt.wantCount, lb.ServerCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
48
internal/helpers/cbor.go
Normal file
48
internal/helpers/cbor.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
//nolint:gochecknoglobals // 使用 sync.Once 模式需要全局变量来确保单次初始化
|
||||
canonicalEncModeOnce sync.Once
|
||||
canonicalEncMode cbor.EncMode //nolint:gochecknoglobals // 使用 sync.Once 模式需要全局变量来确保单次初始化
|
||||
errCanonicalEncMode error
|
||||
)
|
||||
|
||||
// getCanonicalEncMode 获取 Canonical CBOR 编码模式。
|
||||
// 使用 Canonical CBOR 编码模式,确保序列化结果的一致性。
|
||||
// Canonical CBOR 遵循 RFC 7049 Section 3.9,保证相同数据在不同实现间产生相同的字节序列。
|
||||
// 使用 TimeRFC3339Nano 模式确保 time.Time 的纳秒精度被完整保留。
|
||||
func getCanonicalEncMode() (cbor.EncMode, error) {
|
||||
canonicalEncModeOnce.Do(func() {
|
||||
opts := cbor.CanonicalEncOptions()
|
||||
// 设置时间编码模式为 RFC3339Nano,以保留纳秒精度
|
||||
opts.Time = cbor.TimeRFC3339Nano
|
||||
canonicalEncMode, errCanonicalEncMode = opts.EncMode()
|
||||
if errCanonicalEncMode != nil {
|
||||
errCanonicalEncMode = fmt.Errorf("failed to create canonical CBOR encoder: %w", errCanonicalEncMode)
|
||||
}
|
||||
})
|
||||
return canonicalEncMode, errCanonicalEncMode
|
||||
}
|
||||
|
||||
// MarshalCanonical 使用 Canonical CBOR 编码序列化数据。
|
||||
// 确保相同数据在不同实现间产生相同的字节序列,适用于需要确定性序列化的场景。
|
||||
func MarshalCanonical(v interface{}) ([]byte, error) {
|
||||
encMode, err := getCanonicalEncMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encMode.Marshal(v)
|
||||
}
|
||||
|
||||
// Unmarshal 反序列化 CBOR 数据。
|
||||
// 支持标准 CBOR 和 Canonical CBOR 格式。
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
return cbor.Unmarshal(data, v)
|
||||
}
|
||||
177
internal/helpers/cbor_test.go
Normal file
177
internal/helpers/cbor_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package helpers_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
func TestMarshalCanonical(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
input: 42,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
input: map[string]interface{}{"key": "value"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "slice",
|
||||
input: []string{"a", "b", "c"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
input: struct{ Name string }{"test"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
input: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := helpers.MarshalCanonical(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalCanonical_Deterministic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"key3": 123,
|
||||
}
|
||||
|
||||
result1, err1 := helpers.MarshalCanonical(input)
|
||||
require.NoError(t, err1)
|
||||
|
||||
result2, err2 := helpers.MarshalCanonical(input)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Canonical encoding should produce identical results
|
||||
assert.Equal(t, result1, result2)
|
||||
}
|
||||
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
target interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
data: []byte{0x64, 0x74, 0x65, 0x73, 0x74}, // "test" in CBOR
|
||||
target: new(string),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
data: []byte{0x18, 0x2a}, // 42 in CBOR
|
||||
target: new(int),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid CBOR",
|
||||
data: []byte{0xff, 0xff, 0xff},
|
||||
target: new(string),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
target: new(string),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := helpers.Unmarshal(tt.data, tt.target)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshal_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
input: "test string",
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
input: 42,
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
input: map[string]interface{}{"key": "value"},
|
||||
},
|
||||
{
|
||||
name: "slice",
|
||||
input: []string{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Marshal
|
||||
data, err := helpers.MarshalCanonical(tt.input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// Unmarshal
|
||||
var result interface{}
|
||||
err = helpers.Unmarshal(data, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
assert.NotNil(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
76
internal/helpers/cbor_time_test.go
Normal file
76
internal/helpers/cbor_time_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package helpers_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
func TestCBORTimePrecision(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 创建一个包含纳秒精度的时间戳
|
||||
originalTime := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||
|
||||
t.Logf("Original time: %v", originalTime)
|
||||
t.Logf("Original nanoseconds: %d", originalTime.Nanosecond())
|
||||
|
||||
// 序列化
|
||||
data, err := helpers.MarshalCanonical(originalTime)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// 反序列化
|
||||
var decodedTime time.Time
|
||||
err = helpers.Unmarshal(data, &decodedTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Decoded time: %v", decodedTime)
|
||||
t.Logf("Decoded nanoseconds: %d", decodedTime.Nanosecond())
|
||||
|
||||
// 验证纳秒精度是否保留
|
||||
assert.Equal(t, originalTime.UnixNano(), decodedTime.UnixNano(),
|
||||
"纳秒精度应该被保留")
|
||||
assert.Equal(t, originalTime.Nanosecond(), decodedTime.Nanosecond(),
|
||||
"纳秒部分应该相等")
|
||||
}
|
||||
|
||||
func TestCBORTimePrecision_Struct(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type TestStruct struct {
|
||||
Timestamp time.Time `cbor:"timestamp"`
|
||||
}
|
||||
|
||||
// 创建一个包含纳秒精度的时间戳
|
||||
originalTime := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||
original := TestStruct{
|
||||
Timestamp: originalTime,
|
||||
}
|
||||
|
||||
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||
|
||||
// 序列化
|
||||
data, err := helpers.MarshalCanonical(original)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// 反序列化
|
||||
var decoded TestStruct
|
||||
err = helpers.Unmarshal(data, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Decoded timestamp: %v", decoded.Timestamp)
|
||||
t.Logf("Decoded nanoseconds: %d", decoded.Timestamp.Nanosecond())
|
||||
|
||||
// 验证纳秒精度是否保留
|
||||
assert.Equal(t, original.Timestamp.UnixNano(), decoded.Timestamp.UnixNano(),
|
||||
"纳秒精度应该被保留")
|
||||
assert.Equal(t, original.Timestamp.Nanosecond(), decoded.Timestamp.Nanosecond(),
|
||||
"纳秒部分应该相等")
|
||||
}
|
||||
146
internal/helpers/tlv.go
Normal file
146
internal/helpers/tlv.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// TLVReader 提供 TLV(Type-Length-Value)格式的顺序读取能力。
|
||||
// 支持无需反序列化全部报文即可读取特定字段。
|
||||
type TLVReader struct {
|
||||
r io.Reader
|
||||
br io.ByteReader
|
||||
}
|
||||
|
||||
// NewTLVReader 创建新的 TLVReader。
|
||||
func NewTLVReader(r io.Reader) *TLVReader {
|
||||
return &TLVReader{
|
||||
r: r,
|
||||
br: newByteReader(r),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadField 读取下一个 TLV 字段。
|
||||
// 返回字段的长度和值。
|
||||
func (tr *TLVReader) ReadField() ([]byte, error) {
|
||||
length, err := readVarint(tr.br)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read field length: %w", err)
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
value := make([]byte, length)
|
||||
if _, errRead := io.ReadFull(tr.r, value); errRead != nil {
|
||||
return nil, fmt.Errorf("failed to read field value: %w", errRead)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ReadStringField 读取下一个 TLV 字段并转换为字符串。
|
||||
func (tr *TLVReader) ReadStringField() (string, error) {
|
||||
data, err := tr.ReadField()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// TLVWriter 提供 TLV 格式的顺序写入能力。
|
||||
type TLVWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewTLVWriter 创建新的 TLVWriter。
|
||||
func NewTLVWriter(w io.Writer) *TLVWriter {
|
||||
return &TLVWriter{w: w}
|
||||
}
|
||||
|
||||
// WriteField 写入一个 TLV 字段。
|
||||
func (tw *TLVWriter) WriteField(value []byte) error {
|
||||
if err := writeVarint(tw.w, uint64(len(value))); err != nil {
|
||||
return fmt.Errorf("failed to write field length: %w", err)
|
||||
}
|
||||
|
||||
if len(value) > 0 {
|
||||
if _, err := tw.w.Write(value); err != nil {
|
||||
return fmt.Errorf("failed to write field value: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteStringField 写入一个字符串 TLV 字段。
|
||||
func (tw *TLVWriter) WriteStringField(value string) error {
|
||||
return tw.WriteField([]byte(value))
|
||||
}
|
||||
|
||||
// Varint 编码/解码函数
|
||||
|
||||
const (
|
||||
// varintContinueBit 表示 varint 还有后续字节的标志位。
|
||||
varintContinueBit = 0x80
|
||||
// varintDataMask 用于提取 varint 数据位的掩码。
|
||||
varintDataMask = 0x7f
|
||||
// varintMaxShift 表示 varint 最大的位移量,防止溢出。
|
||||
varintMaxShift = 64
|
||||
)
|
||||
|
||||
// writeVarint 写入变长整数(类似 Protobuf 的 varint 编码)。
|
||||
// 将 uint64 编码为变长格式,节省存储空间。
|
||||
//
|
||||
|
||||
func writeVarint(w io.Writer, x uint64) error {
|
||||
var buf [10]byte
|
||||
n := 0
|
||||
for x >= varintContinueBit {
|
||||
buf[n] = byte(x) | varintContinueBit
|
||||
x >>= 7
|
||||
n++
|
||||
}
|
||||
buf[n] = byte(x)
|
||||
_, err := w.Write(buf[:n+1])
|
||||
return err
|
||||
}
|
||||
|
||||
// readVarint 读取变长整数。
|
||||
// 从字节流中解码 varint 格式的整数。
|
||||
func readVarint(r io.ByteReader) (uint64, error) {
|
||||
var x uint64
|
||||
var shift uint
|
||||
for {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
x |= uint64(b&varintDataMask) << shift
|
||||
if b&varintContinueBit == 0 {
|
||||
return x, nil
|
||||
}
|
||||
shift += 7
|
||||
if shift >= varintMaxShift {
|
||||
return 0, errors.New("varint overflow")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// byteReader 为 io.Reader 实现 io.ByteReader 接口。
|
||||
// 提供逐字节读取能力,用于 varint 解码。
|
||||
type byteReader struct {
|
||||
r io.Reader
|
||||
b [1]byte
|
||||
}
|
||||
|
||||
func newByteReader(r io.Reader) io.ByteReader {
|
||||
return &byteReader{r: r}
|
||||
}
|
||||
|
||||
func (br *byteReader) ReadByte() (byte, error) {
|
||||
_, err := br.r.Read(br.b[:])
|
||||
return br.b[0], err
|
||||
}
|
||||
267
internal/helpers/tlv_test.go
Normal file
267
internal/helpers/tlv_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package helpers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
func TestNewTLVReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := bytes.NewReader([]byte{})
|
||||
reader := helpers.NewTLVReader(r)
|
||||
assert.NotNil(t, reader)
|
||||
}
|
||||
|
||||
func TestNewTLVWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
assert.NotNil(t, writer)
|
||||
}
|
||||
|
||||
func TestTLVWriter_WriteField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal field",
|
||||
value: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty field",
|
||||
value: []byte{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "large field",
|
||||
value: make([]byte, 1000),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
err := writer.WriteField(tt.value)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if len(tt.value) > 0 {
|
||||
assert.Positive(t, buf.Len())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLVWriter_WriteStringField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
|
||||
err := writer.WriteStringField("test")
|
||||
require.NoError(t, err)
|
||||
assert.Positive(t, buf.Len())
|
||||
}
|
||||
|
||||
func TestTLVReader_ReadField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() *helpers.TLVReader
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal field",
|
||||
setup: func() *helpers.TLVReader {
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
_ = writer.WriteField([]byte("test"))
|
||||
return helpers.NewTLVReader(&buf)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty field",
|
||||
setup: func() *helpers.TLVReader {
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
_ = writer.WriteField([]byte{})
|
||||
return helpers.NewTLVReader(&buf)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid data",
|
||||
setup: func() *helpers.TLVReader {
|
||||
return helpers.NewTLVReader(bytes.NewReader([]byte{0xff, 0xff}))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty reader",
|
||||
setup: func() *helpers.TLVReader {
|
||||
return helpers.NewTLVReader(bytes.NewReader([]byte{}))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
reader := tt.setup()
|
||||
result, err := reader.ReadField()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
// Empty field returns nil
|
||||
if tt.name == "empty field" {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLVReader_ReadStringField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
err := writer.WriteStringField("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := helpers.NewTLVReader(&buf)
|
||||
result, err := reader.ReadStringField()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test", result)
|
||||
}
|
||||
|
||||
func TestTLV_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value []byte
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
value: []byte("test"),
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
value: []byte{},
|
||||
},
|
||||
{
|
||||
name: "large",
|
||||
value: make([]byte, 100),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Write
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
err := writer.WriteField(tt.value)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read
|
||||
reader := helpers.NewTLVReader(&buf)
|
||||
result, err := reader.ReadField()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
// Empty byte slice returns nil from ReadField
|
||||
if len(tt.value) == 0 {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.Equal(t, tt.value, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLV_MultipleFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
|
||||
// Write multiple fields
|
||||
fields := [][]byte{
|
||||
[]byte("field1"),
|
||||
[]byte("field2"),
|
||||
[]byte("field3"),
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
err := writer.WriteField(field)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read multiple fields
|
||||
reader := helpers.NewTLVReader(&buf)
|
||||
for i, expected := range fields {
|
||||
result, err := reader.ReadField()
|
||||
require.NoError(t, err, "field %d", i)
|
||||
assert.Equal(t, expected, result, "field %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLV_StringRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
|
||||
original := "test string"
|
||||
err := writer.WriteStringField(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := helpers.NewTLVReader(&buf)
|
||||
result, err := reader.ReadStringField()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, result)
|
||||
}
|
||||
|
||||
func TestTLVReader_ReadField_EOF(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
writer := helpers.NewTLVWriter(&buf)
|
||||
_ = writer.WriteField([]byte("test"))
|
||||
|
||||
reader := helpers.NewTLVReader(&buf)
|
||||
_, err := reader.ReadField()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to read beyond EOF - this will fail when trying to read varint length
|
||||
_, err = reader.ReadField()
|
||||
require.Error(t, err)
|
||||
// Error could be EOF or other read error
|
||||
assert.Contains(t, err.Error(), "failed to read")
|
||||
}
|
||||
66
internal/helpers/uuid.go
Normal file
66
internal/helpers/uuid.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// UUID v7 格式常量.
|
||||
uuidRandomBytesSize = 10 // UUID中随机字节部分的大小
|
||||
uuidVersion7 = 0x7000 // UUID v7的版本位
|
||||
uuidVariant = 0x80 // UUID的变体位
|
||||
uuidTimeMask = 0xFFFF // 时间戳掩码
|
||||
uuidTimeShift = 16 // 时间戳位移
|
||||
uuidVariantMask = 0x3F // 变体掩码
|
||||
)
|
||||
|
||||
// NewUUIDv7 生成 UUID v7 并去除连字符.
|
||||
func NewUUIDv7() string {
|
||||
// 获取当前时间戳(Unix 毫秒时间戳)
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// 生成随机字节
|
||||
randBytes := make([]byte, uuidRandomBytesSize)
|
||||
_, err := rand.Read(randBytes)
|
||||
if err != nil {
|
||||
// 如果随机数生成失败,使用时间戳加一些伪随机值作为备选方案
|
||||
return fmt.Sprintf("%016x%016x", now, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// 版本和变体位
|
||||
// 版本: 0x7 (0111) << 12
|
||||
// 变体: 0x2 (10) << 6
|
||||
versionVariant := uint16(uuidVersion7 | uuidVariant)
|
||||
|
||||
// 构建 UUID 字节数组
|
||||
var uuid [16]byte
|
||||
|
||||
// 时间戳低32位 (4 bytes)
|
||||
//nolint:gosec // UUID格式要求的类型转换
|
||||
binary.BigEndian.PutUint32(uuid[0:4], uint32(now>>uuidTimeShift))
|
||||
|
||||
// 时间戳中16位 + 版本 (2 bytes)
|
||||
//nolint:gosec // UUID格式要求的类型转换
|
||||
binary.BigEndian.PutUint16(uuid[4:6], uint16(now&uuidTimeMask))
|
||||
|
||||
// 时间戳高16位 + 变体 (2 bytes)
|
||||
binary.BigEndian.PutUint16(uuid[6:8], versionVariant)
|
||||
|
||||
// 随机数部分 (8 bytes)
|
||||
copy(uuid[8:16], randBytes[:8])
|
||||
|
||||
// 设置变体位 (第8个字节的高两位为10)
|
||||
uuid[8] = (uuid[8] & uuidVariantMask) | uuidVariant
|
||||
|
||||
// 转换为十六进制字符串并去除连字符
|
||||
return fmt.Sprintf("%08x%04x%04x%04x%08x%04x",
|
||||
binary.BigEndian.Uint32(uuid[0:4]),
|
||||
binary.BigEndian.Uint16(uuid[4:6]),
|
||||
binary.BigEndian.Uint16(uuid[6:8]),
|
||||
binary.BigEndian.Uint16(uuid[8:10]),
|
||||
binary.BigEndian.Uint32(uuid[10:14]),
|
||||
binary.BigEndian.Uint16(uuid[14:16]))
|
||||
}
|
||||
151
internal/helpers/uuid_test.go
Normal file
151
internal/helpers/uuid_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package helpers_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
func TestNewUUIDv7(t *testing.T) {
|
||||
// UUID v7 格式:无连字符,32个十六进制字符
|
||||
uuidPattern := regexp.MustCompile(`^[0-9a-f]{32}$`)
|
||||
|
||||
t.Run("生成有效的UUID", func(t *testing.T) {
|
||||
uuid := helpers.NewUUIDv7()
|
||||
|
||||
// 验证格式
|
||||
assert.Len(t, uuid, 32, "UUID长度应该是32个字符")
|
||||
assert.Regexp(t, uuidPattern, uuid, "UUID应该只包含小写十六进制字符")
|
||||
})
|
||||
|
||||
t.Run("每次生成的UUID应该不同", func(t *testing.T) {
|
||||
uuid1 := helpers.NewUUIDv7()
|
||||
uuid2 := helpers.NewUUIDv7()
|
||||
uuid3 := helpers.NewUUIDv7()
|
||||
|
||||
assert.NotEqual(t, uuid1, uuid2)
|
||||
assert.NotEqual(t, uuid2, uuid3)
|
||||
assert.NotEqual(t, uuid1, uuid3)
|
||||
})
|
||||
|
||||
t.Run("UUID格式验证", func(t *testing.T) {
|
||||
uuid := helpers.NewUUIDv7()
|
||||
|
||||
// UUID v7 应该是 32 个十六进制字符
|
||||
require.Len(t, uuid, 32)
|
||||
|
||||
// 检查每个字符都是有效的十六进制
|
||||
for i, c := range uuid {
|
||||
assert.True(t,
|
||||
(c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
|
||||
"字符 %c 在位置 %d 不是有效的十六进制字符", c, i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("并发生成UUID", func(t *testing.T) {
|
||||
const concurrency = 100
|
||||
uuids := make([]string, concurrency)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
for i := range concurrency {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
uuids[idx] = helpers.NewUUIDv7()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证所有 UUID 都不为空且格式正确
|
||||
for i, uuid := range uuids {
|
||||
assert.NotEmpty(t, uuid, "UUID %d 不应该为空", i)
|
||||
assert.Regexp(t, uuidPattern, uuid, "UUID %d 格式不正确", i)
|
||||
}
|
||||
|
||||
// 验证所有 UUID 都是唯一的
|
||||
uniqueMap := make(map[string]bool)
|
||||
for _, uuid := range uuids {
|
||||
assert.False(t, uniqueMap[uuid], "UUID重复: %s", uuid)
|
||||
uniqueMap[uuid] = true
|
||||
}
|
||||
assert.Len(t, uniqueMap, concurrency, "应该生成%d个唯一的UUID", concurrency)
|
||||
})
|
||||
|
||||
t.Run("UUID包含时间戳信息", func(t *testing.T) {
|
||||
// 连续生成多个UUID,它们的时间戳部分应该相近或递增
|
||||
uuid1 := helpers.NewUUIDv7()
|
||||
uuid2 := helpers.NewUUIDv7()
|
||||
|
||||
// UUID v7 的前12个字符主要是时间戳
|
||||
// 在很短的时间内生成的UUID,时间戳部分应该相同或非常接近
|
||||
timePrefix1 := uuid1[:12]
|
||||
timePrefix2 := uuid2[:12]
|
||||
|
||||
// 时间戳应该相同或第二个略大(因为时间在递增)
|
||||
assert.True(t,
|
||||
timePrefix1 == timePrefix2 || timePrefix1 <= timePrefix2,
|
||||
"UUID的时间戳部分应该单调递增")
|
||||
})
|
||||
|
||||
t.Run("批量生成UUID性能测试", func(t *testing.T) {
|
||||
const iterations = 1000
|
||||
uuids := make([]string, iterations)
|
||||
|
||||
for i := range iterations {
|
||||
uuids[i] = helpers.NewUUIDv7()
|
||||
}
|
||||
|
||||
// 验证所有UUID都有效
|
||||
for i, uuid := range uuids {
|
||||
assert.Regexp(t, uuidPattern, uuid, "UUID %d 格式不正确", i)
|
||||
}
|
||||
|
||||
// 简单的唯一性检查
|
||||
uniqueMap := make(map[string]bool)
|
||||
for _, uuid := range uuids {
|
||||
uniqueMap[uuid] = true
|
||||
}
|
||||
assert.Len(t, uniqueMap, iterations, "应该生成%d个唯一的UUID", iterations)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewUUIDv7_Format(t *testing.T) {
|
||||
// 测试UUID v7的具体格式要求
|
||||
uuid := helpers.NewUUIDv7()
|
||||
|
||||
// 总长度 32
|
||||
assert.Len(t, uuid, 32)
|
||||
|
||||
// 全部小写
|
||||
for _, c := range uuid {
|
||||
if c >= 'a' && c <= 'f' {
|
||||
assert.True(t, c >= 'a' && c <= 'f')
|
||||
} else {
|
||||
assert.True(t, c >= '0' && c <= '9')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUUIDv7_EdgeCases(t *testing.T) {
|
||||
t.Run("快速连续生成", func(t *testing.T) {
|
||||
// 在极短时间内生成多个UUID
|
||||
uuids := make([]string, 10)
|
||||
for i := range 10 {
|
||||
uuids[i] = helpers.NewUUIDv7()
|
||||
}
|
||||
|
||||
// 所有UUID应该都有效且唯一
|
||||
seen := make(map[string]bool)
|
||||
for _, uuid := range uuids {
|
||||
assert.Len(t, uuid, 32)
|
||||
assert.False(t, seen[uuid], "UUID不应该重复")
|
||||
seen[uuid] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
20
internal/helpers/validate.go
Normal file
20
internal/helpers/validate.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // 单例模式需要全局变量
|
||||
var (
|
||||
validate *validator.Validate
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetValidator() *validator.Validate {
|
||||
once.Do(func() {
|
||||
validate = validator.New()
|
||||
})
|
||||
return validate
|
||||
}
|
||||
186
internal/helpers/validate_test.go
Normal file
186
internal/helpers/validate_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package helpers_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
func TestGetValidator(t *testing.T) {
|
||||
t.Run("返回有效的validator实例", func(t *testing.T) {
|
||||
v := helpers.GetValidator()
|
||||
|
||||
require.NotNil(t, v, "Validator不应该为nil")
|
||||
})
|
||||
|
||||
t.Run("单例模式:多次调用返回同一个实例", func(t *testing.T) {
|
||||
v1 := helpers.GetValidator()
|
||||
v2 := helpers.GetValidator()
|
||||
v3 := helpers.GetValidator()
|
||||
|
||||
// 使用指针比较,确保是同一个实例
|
||||
assert.Same(t, v1, v2, "第一次和第二次调用应该返回同一个实例")
|
||||
assert.Same(t, v2, v3, "第二次和第三次调用应该返回同一个实例")
|
||||
assert.Same(t, v1, v3, "第一次和第三次调用应该返回同一个实例")
|
||||
})
|
||||
|
||||
t.Run("并发获取validator应该安全", func(t *testing.T) {
|
||||
const concurrency = 100
|
||||
validators := make([]interface{}, concurrency)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
for i := range concurrency {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
v := helpers.GetValidator()
|
||||
// 存储validator实例
|
||||
validators[idx] = v
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证所有goroutine获取的是同一个实例
|
||||
firstValidator := validators[0]
|
||||
for i := 1; i < concurrency; i++ {
|
||||
assert.Same(t, firstValidator, validators[i],
|
||||
"并发调用第%d次获取的validator应该与第一次相同", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validator可以正常工作", func(t *testing.T) {
|
||||
v := helpers.GetValidator()
|
||||
|
||||
// 测试一个简单的结构体验证
|
||||
type TestStruct struct {
|
||||
Name string `validate:"required,min=2,max=10"`
|
||||
Email string `validate:"required,email"`
|
||||
Age int `validate:"gte=0,lte=120"`
|
||||
}
|
||||
|
||||
// 有效的结构体
|
||||
validData := TestStruct{
|
||||
Name: "John",
|
||||
Email: "john@example.com",
|
||||
Age: 30,
|
||||
}
|
||||
err := v.Struct(validData)
|
||||
require.NoError(t, err, "有效的数据不应该产生验证错误")
|
||||
|
||||
// 无效的结构体 - 缺少必填字段
|
||||
invalidData1 := TestStruct{
|
||||
Name: "",
|
||||
Age: 30,
|
||||
}
|
||||
err = v.Struct(invalidData1)
|
||||
require.Error(t, err, "缺少必填字段应该产生验证错误")
|
||||
|
||||
// 无效的结构体 - 字段值超出范围
|
||||
invalidData2 := TestStruct{
|
||||
Name: "John",
|
||||
Email: "john@example.com",
|
||||
Age: 150,
|
||||
}
|
||||
err = v.Struct(invalidData2)
|
||||
require.Error(t, err, "年龄超出范围应该产生验证错误")
|
||||
|
||||
// 无效的结构体 - 邮箱格式错误
|
||||
invalidData3 := TestStruct{
|
||||
Name: "John",
|
||||
Email: "invalid-email",
|
||||
Age: 30,
|
||||
}
|
||||
err = v.Struct(invalidData3)
|
||||
assert.Error(t, err, "无效的邮箱格式应该产生验证错误")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetValidator_InitializationOnce(t *testing.T) {
|
||||
// 这个测试验证 sync.Once 确保初始化只执行一次
|
||||
const calls = 1000
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(calls)
|
||||
|
||||
results := make([]interface{}, calls)
|
||||
|
||||
for i := range calls {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
v := helpers.GetValidator()
|
||||
results[idx] = v
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 所有结果应该指向同一个实例
|
||||
first := results[0]
|
||||
for i := 1; i < calls; i++ {
|
||||
assert.Same(t, first, results[i],
|
||||
"所有调用应该返回完全相同的validator实例")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidator_ValidatorFunctionality(t *testing.T) {
|
||||
v := helpers.GetValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "结构体字段验证-成功",
|
||||
data: struct {
|
||||
Field string `validate:"required"`
|
||||
}{
|
||||
Field: "value",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "结构体字段验证-失败",
|
||||
data: struct {
|
||||
Field string `validate:"required"`
|
||||
}{
|
||||
Field: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "数字范围验证-成功",
|
||||
data: struct {
|
||||
Count int `validate:"min=1,max=100"`
|
||||
}{
|
||||
Count: 50,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数字范围验证-失败",
|
||||
data: struct {
|
||||
Count int `validate:"min=1,max=100"`
|
||||
}{
|
||||
Count: 200,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.Struct(tt.data)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
internal/logger/logger.go
Normal file
194
internal/logger/logger.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill"
|
||||
"github.com/apache/pulsar-client-go/pulsar/log"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// argsPerField 每个字段转换为args时的参数数量(key+value).
|
||||
argsPerField = 2
|
||||
)
|
||||
|
||||
type WatermillLoggerAdapter struct {
|
||||
logger logger.Logger
|
||||
fields watermill.LogFields
|
||||
}
|
||||
|
||||
func (w WatermillLoggerAdapter) Error(msg string, err error, fields watermill.LogFields) {
|
||||
allFields := mergeFields(w.fields, fields)
|
||||
args := allFieldsToArgs(allFields)
|
||||
w.logger.Error(fmt.Sprintf("%s: %v", msg, err), args...)
|
||||
}
|
||||
|
||||
func (w WatermillLoggerAdapter) Info(msg string, fields watermill.LogFields) {
|
||||
allFields := mergeFields(w.fields, fields)
|
||||
args := allFieldsToArgs(allFields)
|
||||
w.logger.Info(msg, args...)
|
||||
}
|
||||
|
||||
func (w WatermillLoggerAdapter) Debug(msg string, fields watermill.LogFields) {
|
||||
allFields := mergeFields(w.fields, fields)
|
||||
args := allFieldsToArgs(allFields)
|
||||
w.logger.Debug(msg, args...)
|
||||
}
|
||||
|
||||
func (w WatermillLoggerAdapter) Trace(msg string, fields watermill.LogFields) {
|
||||
allFields := mergeFields(w.fields, fields)
|
||||
args := allFieldsToArgs(allFields)
|
||||
w.logger.Debug(fmt.Sprintf("[TRACE] %s", msg), args...)
|
||||
}
|
||||
|
||||
func (w WatermillLoggerAdapter) With(fields watermill.LogFields) watermill.LoggerAdapter {
|
||||
newFields := mergeFields(w.fields, fields)
|
||||
return WatermillLoggerAdapter{
|
||||
logger: w.logger,
|
||||
fields: newFields,
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:funcorder // 构造函数放在此处更符合代码组织
|
||||
func NewWatermillLoggerAdapter(logger logger.Logger) *WatermillLoggerAdapter {
|
||||
return &WatermillLoggerAdapter{logger: logger, fields: watermill.LogFields{}}
|
||||
}
|
||||
|
||||
func mergeFields(base, extra watermill.LogFields) watermill.LogFields {
|
||||
merged := make(watermill.LogFields, len(base)+len(extra))
|
||||
for k, v := range base {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range extra {
|
||||
merged[k] = v
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func allFieldsToArgs(fields watermill.LogFields) []any {
|
||||
args := make([]any, 0, len(fields)*argsPerField)
|
||||
for k, v := range fields {
|
||||
args = append(args, k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// ================= PulsarLoggerAdapter ======================
|
||||
|
||||
type PulsarLoggerAdapter struct {
|
||||
logger logger.Logger
|
||||
fields log.Fields
|
||||
err error
|
||||
}
|
||||
|
||||
func NewPulsarLoggerAdapter(l logger.Logger) *PulsarLoggerAdapter {
|
||||
return &PulsarLoggerAdapter{logger: l, fields: log.Fields{}}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) SubLogger(fields log.Fields) log.Logger {
|
||||
return PulsarLoggerAdapter{
|
||||
logger: p.logger,
|
||||
fields: mergePulsarFields(p.fields, fields),
|
||||
err: p.err,
|
||||
}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) WithFields(fields log.Fields) log.Entry {
|
||||
return PulsarLoggerAdapter{
|
||||
logger: p.logger,
|
||||
fields: mergePulsarFields(p.fields, fields),
|
||||
err: p.err,
|
||||
}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) WithField(name string, value interface{}) log.Entry {
|
||||
newFields := mergePulsarFields(p.fields, log.Fields{name: value})
|
||||
return PulsarLoggerAdapter{
|
||||
logger: p.logger,
|
||||
fields: newFields,
|
||||
err: p.err,
|
||||
}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) WithError(err error) log.Entry {
|
||||
return PulsarLoggerAdapter{
|
||||
logger: p.logger,
|
||||
fields: p.fields,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Debug(args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Debug(fmt.Sprint(args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Info(args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Info(fmt.Sprint(args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Warn(args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Warn(fmt.Sprint(args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Error(args ...interface{}) {
|
||||
msg := fmt.Sprint(args...)
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
if p.err != nil {
|
||||
// 将error作为key-value对添加到args中
|
||||
fieldsArgs = append(fieldsArgs, "error", p.err)
|
||||
p.logger.Error(msg, fieldsArgs...)
|
||||
} else {
|
||||
p.logger.Error(msg, fieldsArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Debugf(format string, args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Debug(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Infof(format string, args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Info(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Warnf(format string, args ...interface{}) {
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
p.logger.Warn(fmt.Sprintf(format, args...), fieldsArgs...)
|
||||
}
|
||||
|
||||
func (p PulsarLoggerAdapter) Errorf(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
fieldsArgs := fieldsToArgs(p.fields)
|
||||
if p.err != nil {
|
||||
p.logger.Error(fmt.Sprintf("%s: %v", msg, p.err), fieldsArgs...)
|
||||
} else {
|
||||
p.logger.Error(msg, fieldsArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
// 合并 Pulsar log.Fields.
|
||||
func mergePulsarFields(base, extra log.Fields) log.Fields {
|
||||
merged := make(log.Fields, len(base)+len(extra))
|
||||
for k, v := range base {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range extra {
|
||||
merged[k] = v
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// 将 Pulsar log.Fields 转为 args ...any 形式,适配 Adapter.
|
||||
func fieldsToArgs(fields log.Fields) []any {
|
||||
args := make([]any, 0, len(fields)*argsPerField)
|
||||
for k, v := range fields {
|
||||
args = append(args, k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
385
internal/logger/logger_test.go
Normal file
385
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
apilogger "go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||
)
|
||||
|
||||
func TestNewWatermillLoggerAdapter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
assert.NotNil(t, adapter)
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
err := errors.New("test error")
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Error("error message", err, fields)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Info("info message", fields)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_Debug(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Debug("debug message", fields)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_Trace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Trace("trace message", fields)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_With(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
fields1 := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
}
|
||||
fields2 := map[string]interface{}{
|
||||
"key2": "value2",
|
||||
}
|
||||
|
||||
newAdapter := adapter.With(fields1)
|
||||
assert.NotNil(t, newAdapter)
|
||||
|
||||
// Test that fields are merged
|
||||
newAdapter2 := newAdapter.With(fields2)
|
||||
assert.NotNil(t, newAdapter2)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
newAdapter2.Info("test", map[string]interface{}{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewPulsarLoggerAdapter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
assert.NotNil(t, adapter)
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Debug(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Debug("debug message")
|
||||
adapter.Debug("debug", "message", "with", "args")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Info("info message")
|
||||
adapter.Info("info", "message", "with", "args")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Warn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Warn("warn message")
|
||||
adapter.Warn("warn", "message", "with", "args")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
test func()
|
||||
}{
|
||||
{
|
||||
name: "without error",
|
||||
test: func() {
|
||||
adapter.Error("error message")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
test: func() {
|
||||
adapterWithErr := adapter.WithError(errors.New("test error"))
|
||||
adapterWithErr.Error("error message")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.NotPanics(t, tt.test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Debugf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Debugf("debug %s", "message")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Infof(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Infof("info %s", "message")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Warnf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Warnf("warn %s", "message")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_Errorf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
test func()
|
||||
}{
|
||||
{
|
||||
name: "without error",
|
||||
test: func() {
|
||||
adapter.Errorf("error %s", "message")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
test: func() {
|
||||
adapterWithErr := adapter.WithError(errors.New("test error"))
|
||||
adapterWithErr.Errorf("error %s", "message")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.NotPanics(t, tt.test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_SubLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
}
|
||||
|
||||
subLogger := adapter.SubLogger(fields)
|
||||
assert.NotNil(t, subLogger)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
subLogger.Info("test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_WithFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
entry := adapter.WithFields(fields)
|
||||
assert.NotNil(t, entry)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
entry.Info("test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_WithField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
entry := adapter.WithField("key", "value")
|
||||
assert.NotNil(t, entry)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
entry.Info("test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_WithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
err := errors.New("test error")
|
||||
entry := adapter.WithError(err)
|
||||
assert.NotNil(t, entry)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
entry.Error("test error message")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_ChainedFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
entry1 := adapter.WithField("key1", "value1")
|
||||
entry2 := entry1.WithField("key2", "value2")
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
entry2.Info("chained fields test")
|
||||
})
|
||||
|
||||
// Test WithError separately
|
||||
entryWithErr := adapter.WithError(errors.New("test error"))
|
||||
assert.NotPanics(t, func() {
|
||||
entryWithErr.Error("chained fields test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPulsarLoggerAdapter_FormatMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewPulsarLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Debugf("debug %d", 1)
|
||||
adapter.Infof("info %d", 2)
|
||||
adapter.Warnf("warn %d", 3)
|
||||
adapter.Errorf("error %d", 4)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_EmptyFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapter.Error("error", errors.New("test"), map[string]interface{}{})
|
||||
adapter.Info("info", map[string]interface{}{})
|
||||
adapter.Debug("debug", map[string]interface{}{})
|
||||
adapter.Trace("trace", map[string]interface{}{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWatermillLoggerAdapter_MergedFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := apilogger.NewNopLogger()
|
||||
adapter := logger.NewWatermillLoggerAdapter(l)
|
||||
|
||||
baseFields := map[string]interface{}{
|
||||
"base": "value",
|
||||
}
|
||||
extraFields := map[string]interface{}{
|
||||
"extra": "value",
|
||||
}
|
||||
|
||||
adapterWithBase := adapter.With(baseFields)
|
||||
adapterWithBoth := adapterWithBase.With(extraFields)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
adapterWithBoth.Info("test", map[string]interface{}{})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user