refactor: 重构trustlog-sdk目录结构到trustlog/go-trustlog
- 将所有trustlog-sdk文件移动到trustlog/go-trustlog/目录 - 更新README中所有import路径从trustlog-sdk改为go-trustlog - 更新cookiecutter配置文件中的项目名称 - 更新根目录.lefthook.yml以引用新位置的配置 - 添加go.sum文件到版本控制 - 删除过时的示例文件 这次重构与trustlog-server保持一致的目录结构, 为未来支持多语言SDK(Python、Java等)预留空间。
This commit is contained in:
207
api/model/config_signer.go
Normal file
207
api/model/config_signer.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/crpt/go-crpt"
|
||||
_ "github.com/crpt/go-crpt/ed25519" // 注册 Ed25519
|
||||
_ "github.com/crpt/go-crpt/sm2" // 注册 SM2
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// ConfigSigner 基于配置的通用签名器
|
||||
// 根据 CryptoConfig 自动使用对应的签名算法
|
||||
type ConfigSigner struct {
|
||||
privateKey []byte // 私钥(序列化格式)
|
||||
publicKey []byte // 公钥(序列化格式)
|
||||
config *CryptoConfig // 密码学配置
|
||||
privKey crpt.PrivateKey // 解析后的私钥
|
||||
pubKey crpt.PublicKey // 解析后的公钥
|
||||
}
|
||||
|
||||
// NewConfigSigner 创建基于配置的签名器
|
||||
// 如果 config 为 nil,则使用全局配置
|
||||
func NewConfigSigner(privateKey, publicKey []byte, config *CryptoConfig) (*ConfigSigner, error) {
|
||||
if config == nil {
|
||||
config = GetGlobalCryptoConfig()
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating ConfigSigner",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"privateKeyLength", len(privateKey),
|
||||
"publicKeyLength", len(publicKey),
|
||||
)
|
||||
|
||||
signer := &ConfigSigner{
|
||||
privateKey: privateKey,
|
||||
publicKey: publicKey,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// 延迟解析密钥,只在需要时解析
|
||||
// 这样可以避免初始化顺序问题
|
||||
|
||||
log.Debug("ConfigSigner created successfully",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
)
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
// NewDefaultSigner 创建使用默认 SM2 算法的签名器
|
||||
// 注意:总是使用 SM2,不受全局配置影响
|
||||
func NewDefaultSigner(privateKey, publicKey []byte) (*ConfigSigner, error) {
|
||||
sm2Config := &CryptoConfig{
|
||||
SignatureAlgorithm: SM2Algorithm,
|
||||
}
|
||||
return NewConfigSigner(privateKey, publicKey, sm2Config)
|
||||
}
|
||||
|
||||
// Sign 对数据进行签名
|
||||
func (s *ConfigSigner) Sign(data []byte) ([]byte, error) {
|
||||
if len(s.privateKey) == 0 {
|
||||
return nil, fmt.Errorf("private key is not set")
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
|
||||
// 根据算法类型使用对应的方法
|
||||
switch s.config.SignatureAlgorithm {
|
||||
case SM2Algorithm:
|
||||
// SM2 使用现有的 ComputeSignature 函数(兼容 DER 格式)
|
||||
signature, err := ComputeSignature(data, s.privateKey)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign with SM2",
|
||||
"error", err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Signed successfully with SM2",
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
|
||||
default:
|
||||
// 其他算法使用 crpt 通用接口
|
||||
// 懒加载:解析私钥
|
||||
if s.privKey == nil {
|
||||
keyType, err := s.config.SignatureAlgorithm.toKeyType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
privKey, err := crpt.PrivateKeyFromBytes(keyType, s.privateKey)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse private key",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
s.privKey = privKey
|
||||
}
|
||||
|
||||
signature, err := crpt.SignMessage(s.privKey, data, nil, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to sign: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Signed successfully with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Verify 验证签名
|
||||
func (s *ConfigSigner) Verify(data, signature []byte) (bool, error) {
|
||||
if len(s.publicKey) == 0 {
|
||||
return false, fmt.Errorf("public key is not set")
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
|
||||
// 根据算法类型使用对应的方法
|
||||
switch s.config.SignatureAlgorithm {
|
||||
case SM2Algorithm:
|
||||
// SM2 使用现有的 VerifySignature 函数(兼容 DER 格式)
|
||||
ok, err := VerifySignature(data, s.publicKey, signature)
|
||||
if err != nil {
|
||||
// VerifySignature 在验证失败时也返回错误,需要判断错误类型
|
||||
// 如果是 "signature verification failed",则返回 (false, nil)
|
||||
if ok == false {
|
||||
// 验证失败(不是异常)
|
||||
log.Warn("Verification failed with SM2")
|
||||
return false, nil
|
||||
}
|
||||
// 其他错误(如解析错误)
|
||||
log.Error("Failed to verify with SM2", "error", err)
|
||||
return false, err
|
||||
}
|
||||
log.Debug("Verified successfully with SM2")
|
||||
return true, nil
|
||||
|
||||
default:
|
||||
// 其他算法使用 crpt 通用接口
|
||||
// 懒加载:解析公钥
|
||||
if s.pubKey == nil {
|
||||
keyType, err := s.config.SignatureAlgorithm.toKeyType()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
pubKey, err := crpt.PublicKeyFromBytes(keyType, s.publicKey)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse public key",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"error", err,
|
||||
)
|
||||
return false, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
s.pubKey = pubKey
|
||||
}
|
||||
|
||||
ok, err := crpt.VerifyMessage(s.pubKey, data, crpt.Signature(signature), nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to verify with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
"error", err,
|
||||
)
|
||||
return false, fmt.Errorf("failed to verify: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debug("Verified successfully with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
)
|
||||
} else {
|
||||
log.Warn("Verification failed with ConfigSigner",
|
||||
"algorithm", s.config.SignatureAlgorithm,
|
||||
)
|
||||
}
|
||||
|
||||
return ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetAlgorithm 获取签名器使用的算法
|
||||
func (s *ConfigSigner) GetAlgorithm() SignatureAlgorithm {
|
||||
return s.config.SignatureAlgorithm
|
||||
}
|
||||
158
api/model/config_signer_test.go
Normal file
158
api/model/config_signer_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/crpt/go-crpt/sm2" // 确保 SM2 已注册
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestNewConfigSigner_SM2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成 SM2 密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名器
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.SM2Algorithm,
|
||||
}
|
||||
signer, err := model.NewConfigSigner(privateKeyDER, publicKeyDER, config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, signer)
|
||||
assert.Equal(t, model.SM2Algorithm, signer.GetAlgorithm())
|
||||
}
|
||||
|
||||
func TestNewDefaultSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成 SM2 密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建默认签名器(应该使用 SM2)
|
||||
signer, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, signer)
|
||||
assert.Equal(t, model.SM2Algorithm, signer.GetAlgorithm())
|
||||
}
|
||||
|
||||
func TestConfigSigner_SignAndVerify_SM2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名器
|
||||
signer, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 签名
|
||||
data := []byte("test data for ConfigSigner")
|
||||
signature, err := signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// 验证
|
||||
ok, err := signer.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证错误数据
|
||||
wrongData := []byte("wrong data")
|
||||
ok, err = signer.Verify(wrongData, signature)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestConfigSigner_SignAndVerify_Ed25519(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成 Ed25519 密钥对
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||
}
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名器
|
||||
signer, err := model.NewConfigSigner(privateKeyDER, publicKeyDER, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 签名
|
||||
data := []byte("test data for Ed25519")
|
||||
signature, err := signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// 验证
|
||||
ok, err := signer.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestConfigSigner_CompatibleWithSM2Signer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用 ConfigSigner 签名
|
||||
configSigner, err := model.NewDefaultSigner(privateKeyDER, publicKeyDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("test data")
|
||||
signature1, err := configSigner.Sign(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用 SM2Signer 验证
|
||||
sm2Signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||
ok1, err := sm2Signer.Verify(data, signature1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok1, "SM2Signer should verify ConfigSigner's signature")
|
||||
|
||||
// 使用 SM2Signer 签名
|
||||
signature2, err := sm2Signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用 ConfigSigner 验证
|
||||
ok2, err := configSigner.Verify(data, signature2)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok2, "ConfigSigner should verify SM2Signer's signature")
|
||||
}
|
||||
200
api/model/converter.go
Normal file
200
api/model/converter.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||
)
|
||||
|
||||
// FromProtobuf 将protobuf的OperationData转换为model.Operation.
|
||||
func FromProtobuf(pbOp *pb.OperationData) (*Operation, error) {
|
||||
if pbOp == nil {
|
||||
return nil, errors.New("protobuf operation data is nil")
|
||||
}
|
||||
|
||||
// 转换时间戳
|
||||
if pbOp.GetTimestamp() == nil {
|
||||
return nil, errors.New("timestamp is required")
|
||||
}
|
||||
timestamp := pbOp.GetTimestamp().AsTime()
|
||||
|
||||
// 构建Operation
|
||||
operation := &Operation{
|
||||
OpID: pbOp.GetOpId(),
|
||||
Timestamp: timestamp,
|
||||
OpSource: Source(pbOp.GetOpSource()),
|
||||
OpType: Type(pbOp.GetOpType()),
|
||||
DoPrefix: pbOp.GetDoPrefix(),
|
||||
DoRepository: pbOp.GetDoRepository(),
|
||||
Doid: pbOp.GetDoid(),
|
||||
ProducerID: pbOp.GetProducerId(),
|
||||
OpActor: pbOp.GetOpActor(),
|
||||
// OpAlgorithm和OpMetaHash字段已移除,固定使用Sha256Simd,哈希值由Envelope的OriginalHash提供
|
||||
}
|
||||
|
||||
// 处理可选的哈希字段
|
||||
if reqHash := pbOp.GetRequestBodyHash(); reqHash != "" {
|
||||
operation.RequestBodyHash = &reqHash
|
||||
}
|
||||
if respHash := pbOp.GetResponseBodyHash(); respHash != "" {
|
||||
operation.ResponseBodyHash = &respHash
|
||||
}
|
||||
|
||||
return operation, nil
|
||||
}
|
||||
|
||||
// ToProtobuf 将model.Operation转换为protobuf的OperationData.
|
||||
func ToProtobuf(op *Operation) (*pb.OperationData, error) {
|
||||
if op == nil {
|
||||
return nil, errors.New("operation is nil")
|
||||
}
|
||||
|
||||
// 转换时间戳
|
||||
timestamp := timestamppb.New(op.Timestamp)
|
||||
|
||||
pbOp := &pb.OperationData{
|
||||
OpId: op.OpID,
|
||||
Timestamp: timestamp,
|
||||
OpSource: string(op.OpSource),
|
||||
OpType: string(op.OpType),
|
||||
DoPrefix: op.DoPrefix,
|
||||
DoRepository: op.DoRepository,
|
||||
Doid: op.Doid,
|
||||
ProducerId: op.ProducerID,
|
||||
OpActor: op.OpActor,
|
||||
// OpAlgorithm、OpMetaHash和OpHash字段已移除,固定使用Sha256Simd,哈希值由Envelope的OriginalHash提供
|
||||
}
|
||||
|
||||
// 处理可选的哈希字段
|
||||
if op.RequestBodyHash != nil {
|
||||
pbOp.RequestBodyHash = *op.RequestBodyHash
|
||||
}
|
||||
if op.ResponseBodyHash != nil {
|
||||
pbOp.ResponseBodyHash = *op.ResponseBodyHash
|
||||
}
|
||||
|
||||
return pbOp, nil
|
||||
}
|
||||
|
||||
// FromProtobufValidationResult 将protobuf的ValidationStreamRes转换为model.ValidationResult.
|
||||
func FromProtobufValidationResult(pbRes *pb.ValidationStreamRes) (*ValidationResult, error) {
|
||||
if pbRes == nil {
|
||||
return nil, errors.New("protobuf validation result is nil")
|
||||
}
|
||||
|
||||
result := &ValidationResult{
|
||||
Code: pbRes.GetCode(),
|
||||
Msg: pbRes.GetMsg(),
|
||||
Progress: pbRes.GetProgress(),
|
||||
Proof: ProofFromProtobuf(pbRes.GetProof()), // 取证证明
|
||||
}
|
||||
|
||||
// 如果有操作数据,则转换
|
||||
if pbRes.GetData() != nil {
|
||||
op, err := FromProtobuf(pbRes.GetData())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert operation data: %w", err)
|
||||
}
|
||||
result.Data = op
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RecordFromProtobuf 将protobuf的RecordData转换为model.Record.
|
||||
func RecordFromProtobuf(pbRec *pb.RecordData) (*Record, error) {
|
||||
if pbRec == nil {
|
||||
return nil, errors.New("protobuf record data is nil")
|
||||
}
|
||||
|
||||
// 构建Record
|
||||
record := &Record{
|
||||
ID: pbRec.GetId(),
|
||||
DoPrefix: pbRec.GetDoPrefix(),
|
||||
ProducerID: pbRec.GetProducerId(),
|
||||
Operator: pbRec.GetOperator(),
|
||||
Extra: pbRec.GetExtra(),
|
||||
RCType: pbRec.GetRcType(),
|
||||
}
|
||||
|
||||
// 转换时间戳
|
||||
if pbRec.GetTimestamp() != nil {
|
||||
record.Timestamp = pbRec.GetTimestamp().AsTime()
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RecordToProtobuf 将model.Record转换为protobuf的RecordData.
|
||||
func RecordToProtobuf(rec *Record) (*pb.RecordData, error) {
|
||||
if rec == nil {
|
||||
return nil, errors.New("record is nil")
|
||||
}
|
||||
|
||||
// 转换时间戳
|
||||
timestamp := timestamppb.New(rec.Timestamp)
|
||||
|
||||
pbRec := &pb.RecordData{
|
||||
Id: rec.ID,
|
||||
DoPrefix: rec.DoPrefix,
|
||||
ProducerId: rec.ProducerID,
|
||||
Timestamp: timestamp,
|
||||
Operator: rec.Operator,
|
||||
Extra: rec.Extra,
|
||||
RcType: rec.RCType,
|
||||
}
|
||||
|
||||
return pbRec, nil
|
||||
}
|
||||
|
||||
// RecordValidationResult 包装记录验证的流式响应结果.
|
||||
type RecordValidationResult struct {
|
||||
Code int32 // 状态码(100处理中,200完成,500失败)
|
||||
Msg string // 消息描述
|
||||
Progress string // 当前进度(比如 "50%")
|
||||
Data *Record // 最终完成时返回的记录数据,过程中可为空
|
||||
Proof *Proof // 取证证明(仅在完成时返回)
|
||||
}
|
||||
|
||||
// IsProcessing 判断是否正在处理中.
|
||||
func (r *RecordValidationResult) IsProcessing() bool {
|
||||
return r.Code == ValidationCodeProcessing
|
||||
}
|
||||
|
||||
// IsCompleted 判断是否已完成.
|
||||
func (r *RecordValidationResult) IsCompleted() bool {
|
||||
return r.Code == ValidationCodeCompleted
|
||||
}
|
||||
|
||||
// IsFailed 判断是否失败.
|
||||
func (r *RecordValidationResult) IsFailed() bool {
|
||||
return r.Code >= ValidationCodeFailed
|
||||
}
|
||||
|
||||
// RecordFromProtobufValidationResult 将protobuf的RecordValidationStreamRes转换为model.RecordValidationResult.
|
||||
func RecordFromProtobufValidationResult(pbRes *pb.RecordValidationStreamRes) (*RecordValidationResult, error) {
|
||||
if pbRes == nil {
|
||||
return nil, errors.New("protobuf record validation result is nil")
|
||||
}
|
||||
|
||||
result := &RecordValidationResult{
|
||||
Code: pbRes.GetCode(),
|
||||
Msg: pbRes.GetMsg(),
|
||||
Progress: pbRes.GetProgress(),
|
||||
Proof: ProofFromProtobuf(pbRes.GetProof()), // 取证证明
|
||||
}
|
||||
|
||||
// 如果有记录数据,则转换
|
||||
if pbRes.GetResult() != nil {
|
||||
rec, err := RecordFromProtobuf(pbRes.GetResult())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert record data: %w", err)
|
||||
}
|
||||
result.Data = rec
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
575
api/model/converter_test.go
Normal file
575
api/model/converter_test.go
Normal file
@@ -0,0 +1,575 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestFromProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.FromProtobuf(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "protobuf operation data is nil")
|
||||
}
|
||||
|
||||
func TestFromProtobuf_NoTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbOp := &pb.OperationData{}
|
||||
result, err := model.FromProtobuf(pbOp)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timestamp is required")
|
||||
}
|
||||
|
||||
func TestFromProtobuf_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbOp := &pb.OperationData{
|
||||
OpId: "op-123",
|
||||
Timestamp: timestamppb.New(now),
|
||||
OpSource: "IRP",
|
||||
OpType: "OC_CREATE_HANDLE",
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerId: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
result, err := model.FromProtobuf(pbOp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "op-123", result.OpID)
|
||||
assert.Equal(t, now.Unix(), result.Timestamp.Unix())
|
||||
assert.Equal(t, model.Source("IRP"), result.OpSource)
|
||||
assert.Equal(t, model.Type("OC_CREATE_HANDLE"), result.OpType)
|
||||
assert.Equal(t, "test", result.DoPrefix)
|
||||
assert.Equal(t, "repo", result.DoRepository)
|
||||
assert.Equal(t, "test/repo/123", result.Doid)
|
||||
assert.Equal(t, "producer-1", result.ProducerID)
|
||||
assert.Equal(t, "actor-1", result.OpActor)
|
||||
}
|
||||
|
||||
func TestFromProtobuf_WithHashes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbOp := &pb.OperationData{
|
||||
OpId: "op-123",
|
||||
Timestamp: timestamppb.New(now),
|
||||
OpSource: "DOIP",
|
||||
OpType: "Create",
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerId: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
RequestBodyHash: "req-hash",
|
||||
ResponseBodyHash: "resp-hash",
|
||||
}
|
||||
|
||||
result, err := model.FromProtobuf(pbOp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.NotNil(t, result.RequestBodyHash)
|
||||
assert.Equal(t, "req-hash", *result.RequestBodyHash)
|
||||
assert.NotNil(t, result.ResponseBodyHash)
|
||||
assert.Equal(t, "resp-hash", *result.ResponseBodyHash)
|
||||
}
|
||||
|
||||
func TestFromProtobuf_EmptyHashes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbOp := &pb.OperationData{
|
||||
OpId: "op-123",
|
||||
Timestamp: timestamppb.New(now),
|
||||
OpSource: "DOIP",
|
||||
OpType: "Create",
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerId: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
RequestBodyHash: "",
|
||||
ResponseBodyHash: "",
|
||||
}
|
||||
|
||||
result, err := model.FromProtobuf(pbOp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Nil(t, result.RequestBodyHash)
|
||||
assert.Nil(t, result.ResponseBodyHash)
|
||||
}
|
||||
|
||||
func TestToProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.ToProtobuf(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "operation is nil")
|
||||
}
|
||||
|
||||
func TestToProtobuf_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: now,
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
result, err := model.ToProtobuf(op)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "op-123", result.GetOpId())
|
||||
assert.Equal(t, now.Unix(), result.GetTimestamp().AsTime().Unix())
|
||||
assert.Equal(t, "IRP", result.GetOpSource())
|
||||
assert.Equal(t, "OC_CREATE_HANDLE", result.GetOpType())
|
||||
assert.Equal(t, "test", result.GetDoPrefix())
|
||||
assert.Equal(t, "repo", result.GetDoRepository())
|
||||
assert.Equal(t, "test/repo/123", result.GetDoid())
|
||||
assert.Equal(t, "producer-1", result.GetProducerId())
|
||||
assert.Equal(t, "actor-1", result.GetOpActor())
|
||||
}
|
||||
|
||||
func TestToProtobuf_WithHashes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reqHash := "req-hash"
|
||||
respHash := "resp-hash"
|
||||
now := time.Now()
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: now,
|
||||
OpSource: model.OpSourceDOIP,
|
||||
OpType: model.OpTypeCreate,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
RequestBodyHash: &reqHash,
|
||||
ResponseBodyHash: &respHash,
|
||||
}
|
||||
|
||||
result, err := model.ToProtobuf(op)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "req-hash", result.GetRequestBodyHash())
|
||||
assert.Equal(t, "resp-hash", result.GetResponseBodyHash())
|
||||
}
|
||||
|
||||
func TestToProtobuf_WithoutHashes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: now,
|
||||
OpSource: model.OpSourceDOIP,
|
||||
OpType: model.OpTypeCreate,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
result, err := model.ToProtobuf(op)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Empty(t, result.GetRequestBodyHash())
|
||||
assert.Empty(t, result.GetResponseBodyHash())
|
||||
}
|
||||
|
||||
func TestFromProtobufValidationResult_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.FromProtobufValidationResult(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "protobuf validation result is nil")
|
||||
}
|
||||
|
||||
func TestFromProtobufValidationResult_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.ValidationStreamRes{
|
||||
Code: 100,
|
||||
Msg: "Processing",
|
||||
Progress: "50%",
|
||||
}
|
||||
|
||||
result, err := model.FromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(100), result.Code)
|
||||
assert.Equal(t, "Processing", result.Msg)
|
||||
assert.Equal(t, "50%", result.Progress)
|
||||
assert.Nil(t, result.Data)
|
||||
assert.Nil(t, result.Proof)
|
||||
}
|
||||
|
||||
func TestFromProtobufValidationResult_WithProof(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.ValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Proof: &pb.Proof{
|
||||
Sign: "test-signature",
|
||||
ColItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "hash1", Left: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.FromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(200), result.Code)
|
||||
assert.NotNil(t, result.Proof)
|
||||
assert.Equal(t, "test-signature", result.Proof.Sign)
|
||||
assert.Len(t, result.Proof.ColItems, 1)
|
||||
}
|
||||
|
||||
func TestFromProtobufValidationResult_WithData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbRes := &pb.ValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Data: &pb.OperationData{
|
||||
OpId: "op-123",
|
||||
Timestamp: timestamppb.New(now),
|
||||
OpSource: "IRP",
|
||||
OpType: "OC_CREATE_HANDLE",
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerId: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.FromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(200), result.Code)
|
||||
assert.NotNil(t, result.Data)
|
||||
assert.Equal(t, "op-123", result.Data.OpID)
|
||||
}
|
||||
|
||||
func TestFromProtobufValidationResult_WithInvalidData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.ValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Data: &pb.OperationData{
|
||||
// Missing timestamp
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.FromProtobufValidationResult(pbRes)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to convert operation data")
|
||||
}
|
||||
|
||||
func TestRecordFromProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.RecordFromProtobuf(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "protobuf record data is nil")
|
||||
}
|
||||
|
||||
func TestRecordFromProtobuf_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbRec := &pb.RecordData{
|
||||
Id: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerId: "producer-1",
|
||||
Timestamp: timestamppb.New(now),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra-data"),
|
||||
RcType: "log",
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobuf(pbRec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "rec-123", result.ID)
|
||||
assert.Equal(t, "test", result.DoPrefix)
|
||||
assert.Equal(t, "producer-1", result.ProducerID)
|
||||
assert.Equal(t, now.Unix(), result.Timestamp.Unix())
|
||||
assert.Equal(t, "operator-1", result.Operator)
|
||||
assert.Equal(t, []byte("extra-data"), result.Extra)
|
||||
assert.Equal(t, "log", result.RCType)
|
||||
}
|
||||
|
||||
func TestRecordFromProtobuf_NoTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRec := &pb.RecordData{
|
||||
Id: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerId: "producer-1",
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra-data"),
|
||||
RcType: "log",
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobuf(pbRec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "rec-123", result.ID)
|
||||
assert.True(t, result.Timestamp.IsZero())
|
||||
}
|
||||
|
||||
func TestRecordToProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.RecordToProtobuf(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "record is nil")
|
||||
}
|
||||
|
||||
func TestRecordToProtobuf_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
rec := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: now,
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra-data"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
result, err := model.RecordToProtobuf(rec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, "rec-123", result.GetId())
|
||||
assert.Equal(t, "test", result.GetDoPrefix())
|
||||
assert.Equal(t, "producer-1", result.GetProducerId())
|
||||
assert.Equal(t, now.Unix(), result.GetTimestamp().AsTime().Unix())
|
||||
assert.Equal(t, "operator-1", result.GetOperator())
|
||||
assert.Equal(t, []byte("extra-data"), result.GetExtra())
|
||||
assert.Equal(t, "log", result.GetRcType())
|
||||
}
|
||||
|
||||
func TestRecordFromProtobufValidationResult_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := model.RecordFromProtobufValidationResult(nil)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "protobuf record validation result is nil")
|
||||
}
|
||||
|
||||
func TestRecordFromProtobufValidationResult_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.RecordValidationStreamRes{
|
||||
Code: 100,
|
||||
Msg: "Processing",
|
||||
Progress: "50%",
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(100), result.Code)
|
||||
assert.Equal(t, "Processing", result.Msg)
|
||||
assert.Equal(t, "50%", result.Progress)
|
||||
assert.Nil(t, result.Data)
|
||||
assert.Nil(t, result.Proof)
|
||||
}
|
||||
|
||||
func TestRecordFromProtobufValidationResult_WithProof(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.RecordValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Proof: &pb.Proof{
|
||||
Sign: "test-signature",
|
||||
RawItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "hash1", Left: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(200), result.Code)
|
||||
assert.NotNil(t, result.Proof)
|
||||
assert.Equal(t, "test-signature", result.Proof.Sign)
|
||||
assert.Len(t, result.Proof.RawItems, 1)
|
||||
}
|
||||
|
||||
func TestRecordFromProtobufValidationResult_WithData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbRes := &pb.RecordValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Result: &pb.RecordData{
|
||||
Id: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerId: "producer-1",
|
||||
Timestamp: timestamppb.New(now),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra-data"),
|
||||
RcType: "log",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, int32(200), result.Code)
|
||||
assert.NotNil(t, result.Data)
|
||||
assert.Equal(t, "rec-123", result.Data.ID)
|
||||
}
|
||||
|
||||
func TestRecordFromProtobufValidationResult_WithInvalidData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbRes := &pb.RecordValidationStreamRes{
|
||||
Code: 200,
|
||||
Msg: "Completed",
|
||||
Progress: "100%",
|
||||
Result: &pb.RecordData{
|
||||
// Missing required fields to trigger error
|
||||
},
|
||||
}
|
||||
|
||||
result, err := model.RecordFromProtobufValidationResult(pbRes)
|
||||
// This should succeed even with empty RecordData
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, int32(200), result.Code)
|
||||
}
|
||||
|
||||
func TestRoundTrip_Operation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
original := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: now,
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
// Convert to protobuf
|
||||
pbOp, err := model.ToProtobuf(original)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pbOp)
|
||||
|
||||
// Convert back to model
|
||||
result, err := model.FromProtobuf(pbOp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify round trip
|
||||
assert.Equal(t, original.OpID, result.OpID)
|
||||
assert.Equal(t, original.OpSource, result.OpSource)
|
||||
assert.Equal(t, original.OpType, result.OpType)
|
||||
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||
assert.Equal(t, original.DoRepository, result.DoRepository)
|
||||
assert.Equal(t, original.Doid, result.Doid)
|
||||
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||
assert.Equal(t, original.OpActor, result.OpActor)
|
||||
}
|
||||
|
||||
func TestRoundTrip_Record(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
original := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: now,
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra-data"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
// Convert to protobuf
|
||||
pbRec, err := model.RecordToProtobuf(original)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pbRec)
|
||||
|
||||
// Convert back to model
|
||||
result, err := model.RecordFromProtobuf(pbRec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify round trip
|
||||
assert.Equal(t, original.ID, result.ID)
|
||||
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||
assert.Equal(t, original.Timestamp.Unix(), result.Timestamp.Unix())
|
||||
assert.Equal(t, original.Operator, result.Operator)
|
||||
assert.Equal(t, original.Extra, result.Extra)
|
||||
assert.Equal(t, original.RCType, result.RCType)
|
||||
}
|
||||
310
api/model/crypto_config.go
Normal file
310
api/model/crypto_config.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/crpt/go-crpt"
|
||||
_ "github.com/crpt/go-crpt/ed25519" // Import Ed25519
|
||||
_ "github.com/crpt/go-crpt/sm2" // Import SM2
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// SignatureAlgorithm 定义支持的签名算法类型.
|
||||
type SignatureAlgorithm string
|
||||
|
||||
const (
|
||||
// SM2 国密SM2算法
|
||||
SM2Algorithm SignatureAlgorithm = "sm2"
|
||||
// Ed25519 Ed25519算法
|
||||
Ed25519Algorithm SignatureAlgorithm = "ed25519"
|
||||
)
|
||||
|
||||
// CryptoConfig 密码学配置
|
||||
type CryptoConfig struct {
|
||||
// SignatureAlgorithm 签名算法类型
|
||||
// SM2 会自动使用 SM3 哈希,Ed25519 会使用 SHA512 哈希
|
||||
SignatureAlgorithm SignatureAlgorithm
|
||||
}
|
||||
|
||||
var (
|
||||
// 默认配置:使用 SM2(内部自动使用 SM3)
|
||||
defaultConfig = &CryptoConfig{
|
||||
SignatureAlgorithm: SM2Algorithm,
|
||||
}
|
||||
|
||||
// 全局配置
|
||||
globalConfig *CryptoConfig
|
||||
globalConfigMutex sync.RWMutex
|
||||
|
||||
// ErrUnsupportedAlgorithm 不支持的算法错误
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported signature algorithm")
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 自动初始化全局配置为 SM2
|
||||
globalConfig = defaultConfig
|
||||
logger.GetGlobalLogger().Debug("Crypto config initialized with default SM2")
|
||||
}
|
||||
|
||||
// SetGlobalCryptoConfig 设置全局密码学配置
|
||||
func SetGlobalCryptoConfig(config *CryptoConfig) error {
|
||||
if config == nil {
|
||||
return errors.New("config cannot be nil")
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := config.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
globalConfigMutex.Lock()
|
||||
defer globalConfigMutex.Unlock()
|
||||
|
||||
globalConfig = config
|
||||
logger.GetGlobalLogger().Info("Global crypto config updated",
|
||||
"signatureAlgorithm", config.SignatureAlgorithm,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGlobalCryptoConfig 获取全局密码学配置
|
||||
func GetGlobalCryptoConfig() *CryptoConfig {
|
||||
globalConfigMutex.RLock()
|
||||
defer globalConfigMutex.RUnlock()
|
||||
|
||||
if globalConfig == nil {
|
||||
return defaultConfig
|
||||
}
|
||||
return globalConfig
|
||||
}
|
||||
|
||||
// Validate 验证配置是否有效
|
||||
func (c *CryptoConfig) Validate() error {
|
||||
// 验证签名算法
|
||||
switch c.SignatureAlgorithm {
|
||||
case SM2Algorithm, Ed25519Algorithm:
|
||||
// 支持的算法
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, c.SignatureAlgorithm)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// toKeyType 将 SignatureAlgorithm 转换为 crpt.KeyType
|
||||
func (a SignatureAlgorithm) toKeyType() (crpt.KeyType, error) {
|
||||
switch a {
|
||||
case SM2Algorithm:
|
||||
return crpt.SM2, nil
|
||||
case Ed25519Algorithm:
|
||||
return crpt.Ed25519, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, a)
|
||||
}
|
||||
}
|
||||
|
||||
// KeyPair 通用密钥对,支持多种算法
|
||||
type KeyPair struct {
|
||||
Public crpt.PublicKey `json:"publicKey"`
|
||||
Private crpt.PrivateKey `json:"privateKey"`
|
||||
Algorithm SignatureAlgorithm
|
||||
}
|
||||
|
||||
// GenerateKeyPair 根据配置生成密钥对
|
||||
func GenerateKeyPair(config *CryptoConfig) (*KeyPair, error) {
|
||||
if config == nil {
|
||||
config = GetGlobalCryptoConfig()
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Generating key pair",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
)
|
||||
|
||||
keyType, err := config.SignatureAlgorithm.toKeyType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pub, priv, err := crpt.GenerateKey(keyType, rand.Reader)
|
||||
if err != nil {
|
||||
log.Error("Failed to generate key pair",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to generate %s key pair: %w", config.SignatureAlgorithm, err)
|
||||
}
|
||||
|
||||
log.Debug("Key pair generated successfully",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
)
|
||||
|
||||
return &KeyPair{
|
||||
Public: pub,
|
||||
Private: priv,
|
||||
Algorithm: config.SignatureAlgorithm,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Sign 使用密钥对签名数据
|
||||
func (kp *KeyPair) Sign(data []byte, rand io.Reader) ([]byte, error) {
|
||||
if rand == nil {
|
||||
rand = defaultRand()
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing data",
|
||||
"algorithm", kp.Algorithm,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
|
||||
signature, err := crpt.SignMessage(kp.Private, data, rand, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign data",
|
||||
"algorithm", kp.Algorithm,
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to sign with %s: %w", kp.Algorithm, err)
|
||||
}
|
||||
|
||||
log.Debug("Data signed successfully",
|
||||
"algorithm", kp.Algorithm,
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// Verify 使用公钥验证签名
|
||||
func (kp *KeyPair) Verify(data, signature []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying signature",
|
||||
"algorithm", kp.Algorithm,
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
|
||||
ok, err := crpt.VerifyMessage(kp.Public, data, crpt.Signature(signature), nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to verify signature",
|
||||
"algorithm", kp.Algorithm,
|
||||
"error", err,
|
||||
)
|
||||
return false, fmt.Errorf("failed to verify with %s: %w", kp.Algorithm, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debug("Signature verified successfully",
|
||||
"algorithm", kp.Algorithm,
|
||||
)
|
||||
} else {
|
||||
log.Warn("Signature verification failed",
|
||||
"algorithm", kp.Algorithm,
|
||||
)
|
||||
}
|
||||
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// MarshalPrivateKey 序列化私钥
|
||||
func (kp *KeyPair) MarshalPrivateKey() ([]byte, error) {
|
||||
if kp.Private == nil {
|
||||
return nil, errors.New("private key is nil")
|
||||
}
|
||||
return kp.Private.Bytes(), nil
|
||||
}
|
||||
|
||||
// MarshalPublicKey 序列化公钥
|
||||
func (kp *KeyPair) MarshalPublicKey() ([]byte, error) {
|
||||
if kp.Public == nil {
|
||||
return nil, errors.New("public key is nil")
|
||||
}
|
||||
return kp.Public.Bytes(), nil
|
||||
}
|
||||
|
||||
// ParsePrivateKey 解析私钥
|
||||
func ParsePrivateKey(data []byte, algorithm SignatureAlgorithm) (crpt.PrivateKey, error) {
|
||||
keyType, err := algorithm.toKeyType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return crpt.PrivateKeyFromBytes(keyType, data)
|
||||
}
|
||||
|
||||
// ParsePublicKey 解析公钥
|
||||
func ParsePublicKey(data []byte, algorithm SignatureAlgorithm) (crpt.PublicKey, error) {
|
||||
keyType, err := algorithm.toKeyType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return crpt.PublicKeyFromBytes(keyType, data)
|
||||
}
|
||||
|
||||
// defaultRand 返回默认的随机数生成器
|
||||
func defaultRand() io.Reader {
|
||||
return rand.Reader
|
||||
}
|
||||
|
||||
// SignWithConfig 使用指定配置签名数据
|
||||
func SignWithConfig(data, privateKeyDER []byte, config *CryptoConfig) ([]byte, error) {
|
||||
if config == nil {
|
||||
config = GetGlobalCryptoConfig()
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing with config",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
|
||||
privateKey, err := ParsePrivateKey(privateKeyDER, config.SignatureAlgorithm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
signature, err := crpt.SignMessage(privateKey, data, rand.Reader, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Signed with config successfully",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// VerifyWithConfig 使用指定配置验证签名
|
||||
func VerifyWithConfig(data, publicKeyDER, signature []byte, config *CryptoConfig) (bool, error) {
|
||||
if config == nil {
|
||||
config = GetGlobalCryptoConfig()
|
||||
}
|
||||
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying with config",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
|
||||
publicKey, err := ParsePublicKey(publicKeyDER, config.SignatureAlgorithm)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
ok, err := crpt.VerifyMessage(publicKey, data, crpt.Signature(signature), nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to verify: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Verified with config",
|
||||
"algorithm", config.SignatureAlgorithm,
|
||||
"result", ok,
|
||||
)
|
||||
return ok, nil
|
||||
}
|
||||
251
api/model/crypto_config_test.go
Normal file
251
api/model/crypto_config_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestCryptoConfig_Validate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *model.CryptoConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid SM2 config",
|
||||
config: &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.SM2Algorithm,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid Ed25519 config",
|
||||
config: &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid signature algorithm",
|
||||
config: &model.CryptoConfig{
|
||||
SignatureAlgorithm: "rsa",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGetGlobalCryptoConfig(t *testing.T) {
|
||||
// 不使用 t.Parallel(),因为它修改全局状态
|
||||
|
||||
// 保存当前配置
|
||||
original := model.GetGlobalCryptoConfig()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||
}
|
||||
|
||||
err := model.SetGlobalCryptoConfig(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved := model.GetGlobalCryptoConfig()
|
||||
assert.Equal(t, config.SignatureAlgorithm, retrieved.SignatureAlgorithm)
|
||||
|
||||
// 恢复原配置
|
||||
_ = model.SetGlobalCryptoConfig(original)
|
||||
}
|
||||
|
||||
func TestGenerateKeyPair_SM2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.SM2Algorithm,
|
||||
}
|
||||
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, keyPair)
|
||||
assert.NotNil(t, keyPair.Public)
|
||||
assert.NotNil(t, keyPair.Private)
|
||||
assert.Equal(t, model.SM2Algorithm, keyPair.Algorithm)
|
||||
}
|
||||
|
||||
func TestGenerateKeyPair_Ed25519(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||
}
|
||||
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, keyPair)
|
||||
assert.NotNil(t, keyPair.Public)
|
||||
assert.NotNil(t, keyPair.Private)
|
||||
assert.Equal(t, model.Ed25519Algorithm, keyPair.Algorithm)
|
||||
}
|
||||
|
||||
func TestKeyPair_SignAndVerify_SM2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.SM2Algorithm,
|
||||
}
|
||||
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("test data for SM2 signing")
|
||||
|
||||
// Sign
|
||||
signature, err := keyPair.Sign(data, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// Verify
|
||||
ok, err := keyPair.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Verify with wrong data should fail
|
||||
wrongData := []byte("wrong data")
|
||||
ok, err = keyPair.Verify(wrongData, signature)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestKeyPair_SignAndVerify_Ed25519(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.Ed25519Algorithm,
|
||||
}
|
||||
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := []byte("test data for Ed25519 signing")
|
||||
|
||||
// Sign
|
||||
signature, err := keyPair.Sign(data, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// Verify
|
||||
ok, err := keyPair.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Verify with wrong data should fail
|
||||
wrongData := []byte("wrong data")
|
||||
ok, err = keyPair.Verify(wrongData, signature)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestKeyPair_MarshalAndParse_SM2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: model.SM2Algorithm,
|
||||
}
|
||||
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal private key
|
||||
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, privateKeyDER)
|
||||
|
||||
// Marshal public key
|
||||
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, publicKeyDER)
|
||||
|
||||
// Parse keys back
|
||||
parsedPriv, err := model.ParsePrivateKey(privateKeyDER, model.SM2Algorithm)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, parsedPriv)
|
||||
|
||||
parsedPub, err := model.ParsePublicKey(publicKeyDER, model.SM2Algorithm)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, parsedPub)
|
||||
|
||||
// Test sign/verify with parsed keys
|
||||
data := []byte("test data")
|
||||
signature, err := model.SignWithConfig(data, privateKeyDER, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok, err := model.VerifyWithConfig(data, publicKeyDER, signature, config)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestSignWithConfig_And_VerifyWithConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
algorithm model.SignatureAlgorithm
|
||||
}{
|
||||
{
|
||||
name: "SM2",
|
||||
algorithm: model.SM2Algorithm,
|
||||
},
|
||||
{
|
||||
name: "Ed25519",
|
||||
algorithm: model.Ed25519Algorithm,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &model.CryptoConfig{
|
||||
SignatureAlgorithm: tt.algorithm,
|
||||
}
|
||||
|
||||
// Generate key pair
|
||||
keyPair, err := model.GenerateKeyPair(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal keys
|
||||
privateKeyDER, err := keyPair.MarshalPrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := keyPair.MarshalPublicKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign
|
||||
data := []byte("test data")
|
||||
signature, err := model.SignWithConfig(data, privateKeyDER, config)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// Verify
|
||||
ok, err := model.VerifyWithConfig(data, publicKeyDER, signature, config)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
501
api/model/envelope.go
Normal file
501
api/model/envelope.go
Normal file
@@ -0,0 +1,501 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
// Envelope 包装序列化后的数据,包含元信息和报文体。
|
||||
// 用于 Trustlog 接口类型的序列化和反序列化。
|
||||
type Envelope struct {
|
||||
ProducerID string // 日志提供者ID
|
||||
Signature []byte // 签名(根据客户端密钥与指定算法进行签名,二进制格式)
|
||||
Body []byte // CBOR序列化的报文体
|
||||
}
|
||||
|
||||
// EnvelopeConfig 序列化配置。
|
||||
type EnvelopeConfig struct {
|
||||
Signer Signer // 签名器,用于签名和验签
|
||||
}
|
||||
|
||||
// VerifyConfig 验签配置。
|
||||
type VerifyConfig struct {
|
||||
Signer Signer // 签名器,用于验签
|
||||
}
|
||||
|
||||
// NewEnvelopeConfig 创建Envelope配置。
|
||||
func NewEnvelopeConfig(signer Signer) EnvelopeConfig {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new EnvelopeConfig",
|
||||
"signerType", fmt.Sprintf("%T", signer),
|
||||
)
|
||||
return EnvelopeConfig{
|
||||
Signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSM2EnvelopeConfig 创建使用SM2签名的Envelope配置。
|
||||
// 便捷方法,用于快速创建SM2签名器配置。
|
||||
func NewSM2EnvelopeConfig(privateKey, publicKey []byte) EnvelopeConfig {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new SM2 EnvelopeConfig",
|
||||
"privateKeyLength", len(privateKey),
|
||||
"publicKeyLength", len(publicKey),
|
||||
)
|
||||
return EnvelopeConfig{
|
||||
Signer: NewSM2Signer(privateKey, publicKey),
|
||||
}
|
||||
}
|
||||
|
||||
// NewVerifyConfig 创建验签配置。
|
||||
func NewVerifyConfig(signer Signer) VerifyConfig {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new VerifyConfig",
|
||||
"signerType", fmt.Sprintf("%T", signer),
|
||||
)
|
||||
return VerifyConfig{
|
||||
Signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSM2VerifyConfig 创建使用SM2签名的验签配置。
|
||||
// 便捷方法,用于快速创建SM2签名器验签配置。
|
||||
// 注意:验签只需要公钥,但SM2Signer需要同时提供私钥和公钥(私钥可以为空)。
|
||||
func NewSM2VerifyConfig(publicKey []byte) VerifyConfig {
|
||||
return VerifyConfig{
|
||||
Signer: NewSM2Signer(nil, publicKey),
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// ===== Envelope 序列化/反序列化 =====
|
||||
//
|
||||
|
||||
// MarshalEnvelope 将 Envelope 序列化为 TLV 格式(Varint长度编码)。
|
||||
// 格式:[字段1长度][字段1值:producerID][字段2长度][字段2值:签名][字段3长度][字段3值:CBOR报文体]。
|
||||
func MarshalEnvelope(env *Envelope) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling envelope to TLV format")
|
||||
if env == nil {
|
||||
log.Error("Envelope is nil")
|
||||
return nil, errors.New("envelope cannot be nil")
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
writer := helpers.NewTLVWriter(buf)
|
||||
|
||||
log.Debug("Writing producerID to TLV",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
if err := writer.WriteStringField(env.ProducerID); err != nil {
|
||||
log.Error("Failed to write producerID",
|
||||
"error", err,
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to write producerID: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Writing signature to TLV",
|
||||
"signatureLength", len(env.Signature),
|
||||
)
|
||||
if err := writer.WriteField(env.Signature); err != nil {
|
||||
log.Error("Failed to write signature",
|
||||
"error", err,
|
||||
"signatureLength", len(env.Signature),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to write signature: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Writing body to TLV",
|
||||
"bodyLength", len(env.Body),
|
||||
)
|
||||
if err := writer.WriteField(env.Body); err != nil {
|
||||
log.Error("Failed to write body",
|
||||
"error", err,
|
||||
"bodyLength", len(env.Body),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to write body: %w", err)
|
||||
}
|
||||
|
||||
result := buf.Bytes()
|
||||
log.Debug("Envelope marshaled successfully",
|
||||
"producerID", env.ProducerID,
|
||||
"totalLength", len(result),
|
||||
)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UnmarshalEnvelope 完整反序列化:读取所有字段。
|
||||
// 解析完整的Envelope结构,包括所有元数据和Body。
|
||||
// 为了向后兼容,如果遇到旧格式(包含原hash字段),会自动跳过该字段。
|
||||
func UnmarshalEnvelope(data []byte) (*Envelope, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Unmarshaling envelope from TLV format",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
if len(data) == 0 {
|
||||
log.Error("Data is empty")
|
||||
return nil, errors.New("data is empty")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
reader := helpers.NewTLVReader(r)
|
||||
|
||||
log.Debug("Reading producerID from TLV")
|
||||
producerID, err := reader.ReadStringField()
|
||||
if err != nil {
|
||||
log.Error("Failed to read producerID",
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to read producerID: %w", err)
|
||||
}
|
||||
log.Debug("ProducerID read successfully",
|
||||
"producerID", producerID,
|
||||
)
|
||||
|
||||
// 读取第一个字段(可能是原hash或签名)
|
||||
log.Debug("Reading field 1 from TLV")
|
||||
field1, err := reader.ReadField()
|
||||
if err != nil {
|
||||
log.Error("Failed to read field 1",
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to read field 1: %w", err)
|
||||
}
|
||||
log.Debug("Field 1 read successfully",
|
||||
"field1Length", len(field1),
|
||||
)
|
||||
|
||||
// 读取第二个字段(可能是签名或body)
|
||||
log.Debug("Reading field 2 from TLV")
|
||||
field2, err := reader.ReadField()
|
||||
if err != nil {
|
||||
log.Error("Failed to read field 2",
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to read field 2: %w", err)
|
||||
}
|
||||
log.Debug("Field 2 read successfully",
|
||||
"field2Length", len(field2),
|
||||
)
|
||||
|
||||
// 尝试读取第三个字段来判断格式
|
||||
log.Debug("Attempting to read field 3 to determine format")
|
||||
field3, err := reader.ReadField()
|
||||
if err == nil {
|
||||
// 有第三个字段,说明是旧格式:producerID, originalHash, encryptedHash, body
|
||||
// field1 = originalHash, field2 = encryptedHash/signature, field3 = body
|
||||
log.Debug("Detected old format (with originalHash)",
|
||||
"producerID", producerID,
|
||||
"signatureLength", len(field2),
|
||||
"bodyLength", len(field3),
|
||||
)
|
||||
return &Envelope{
|
||||
ProducerID: producerID,
|
||||
Signature: field2,
|
||||
Body: field3,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 没有第三个字段,说明是新格式:producerID, signature, body
|
||||
// field1 = signature, field2 = body
|
||||
log.Debug("Detected new format (without originalHash)",
|
||||
"producerID", producerID,
|
||||
"signatureLength", len(field1),
|
||||
"bodyLength", len(field2),
|
||||
)
|
||||
return &Envelope{
|
||||
ProducerID: producerID,
|
||||
Signature: field1,
|
||||
Body: field2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 部分反序列化(无需反序列化全部报文) =====
|
||||
//
|
||||
|
||||
// UnmarshalEnvelopeProducerID 部分反序列化:只读取字段1(producerID)。
|
||||
// 用于快速获取producerID而不解析整个Envelope。
|
||||
func UnmarshalEnvelopeProducerID(data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("data is empty")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
reader := helpers.NewTLVReader(r)
|
||||
|
||||
producerID, err := reader.ReadStringField()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read producerID: %w", err)
|
||||
}
|
||||
|
||||
return producerID, nil
|
||||
}
|
||||
|
||||
// UnmarshalEnvelopeSignature 部分反序列化:读取字段1、2(producerID, 签名)。
|
||||
// 用于获取签名信息而不解析整个Body。
|
||||
// 为了向后兼容,如果遇到旧格式(包含原hash字段),会自动跳过该字段。
|
||||
func UnmarshalEnvelopeSignature(data []byte) (string, []byte, error) {
|
||||
if len(data) == 0 {
|
||||
return "", nil, errors.New("data is empty")
|
||||
}
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
reader := helpers.NewTLVReader(r)
|
||||
|
||||
producerID, err := reader.ReadStringField()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to read producerID: %w", err)
|
||||
}
|
||||
|
||||
// 读取第一个字段(可能是原hash或签名)
|
||||
field1, err := reader.ReadField()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to read field 1: %w", err)
|
||||
}
|
||||
|
||||
// 读取第二个字段(可能是签名或body)
|
||||
field2, err := reader.ReadField()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to read field 2: %w", err)
|
||||
}
|
||||
|
||||
// 尝试读取第三个字段来判断格式
|
||||
_, err = reader.ReadField()
|
||||
if err == nil {
|
||||
// 有第三个字段,说明是旧格式:producerID, originalHash, encryptedHash/signature, body
|
||||
// field1 = originalHash, field2 = signature
|
||||
return producerID, field2, nil
|
||||
}
|
||||
|
||||
// 没有第三个字段,说明是新格式:producerID, signature, body
|
||||
// field1 = signature
|
||||
return producerID, field1, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== Trustlog 序列化/反序列化 =====
|
||||
//
|
||||
|
||||
// MarshalTrustlog 序列化 Trustlog 为 Envelope 格式。
|
||||
// Trustlog 实现了 Trustlog 接口,自动提取 producerID 并使用 Canonical CBOR 编码。
|
||||
func MarshalTrustlog(t Trustlog, config EnvelopeConfig) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling Trustlog to Envelope format",
|
||||
"trustlogType", fmt.Sprintf("%T", t),
|
||||
)
|
||||
if t == nil {
|
||||
log.Error("Trustlog is nil")
|
||||
return nil, errors.New("trustlog cannot be nil")
|
||||
}
|
||||
|
||||
producerID := t.GetProducerID()
|
||||
if producerID == "" {
|
||||
log.Error("ProducerID is empty")
|
||||
return nil, errors.New("producerID cannot be empty")
|
||||
}
|
||||
log.Debug("ProducerID extracted",
|
||||
"producerID", producerID,
|
||||
)
|
||||
|
||||
// 1. 序列化CBOR报文体(使用 Trustlog 的 MarshalBinary,确保使用 Canonical CBOR)
|
||||
log.Debug("Marshaling trustlog to CBOR binary")
|
||||
bodyCBOR, err := t.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal trustlog to CBOR",
|
||||
"error", err,
|
||||
"producerID", producerID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to marshal trustlog: %w", err)
|
||||
}
|
||||
log.Debug("Trustlog marshaled to CBOR successfully",
|
||||
"producerID", producerID,
|
||||
"bodyLength", len(bodyCBOR),
|
||||
)
|
||||
|
||||
// 2. 计算签名
|
||||
if config.Signer == nil {
|
||||
log.Error("Signer is nil")
|
||||
return nil, errors.New("signer is required")
|
||||
}
|
||||
log.Debug("Signing trustlog body",
|
||||
"producerID", producerID,
|
||||
"bodyLength", len(bodyCBOR),
|
||||
)
|
||||
signature, err := config.Signer.Sign(bodyCBOR)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign trustlog body",
|
||||
"error", err,
|
||||
"producerID", producerID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to sign data: %w", err)
|
||||
}
|
||||
log.Debug("Trustlog body signed successfully",
|
||||
"producerID", producerID,
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
|
||||
// 3. 构建Envelope
|
||||
env := &Envelope{
|
||||
ProducerID: producerID,
|
||||
Signature: signature,
|
||||
Body: bodyCBOR,
|
||||
}
|
||||
|
||||
// 4. 序列化为TLV格式
|
||||
log.Debug("Marshaling envelope to TLV format",
|
||||
"producerID", producerID,
|
||||
)
|
||||
return MarshalEnvelope(env)
|
||||
}
|
||||
|
||||
// UnmarshalTrustlog 反序列化 Envelope 为 Trustlog。
|
||||
// 解析Envelope数据并恢复 Trustlog 结构。
|
||||
func UnmarshalTrustlog(data []byte, t Trustlog) error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Unmarshaling Envelope to Trustlog",
|
||||
"trustlogType", fmt.Sprintf("%T", t),
|
||||
"dataLength", len(data),
|
||||
)
|
||||
if t == nil {
|
||||
log.Error("Trustlog is nil")
|
||||
return errors.New("trustlog cannot be nil")
|
||||
}
|
||||
|
||||
env, err := UnmarshalEnvelope(data)
|
||||
if err != nil {
|
||||
log.Error("Failed to unmarshal envelope",
|
||||
"error", err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
log.Debug("Envelope unmarshaled successfully",
|
||||
"producerID", env.ProducerID,
|
||||
"bodyLength", len(env.Body),
|
||||
)
|
||||
|
||||
// 使用 Trustlog 的 UnmarshalBinary 反序列化
|
||||
log.Debug("Unmarshaling trustlog body from CBOR",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
if errUnmarshal := t.UnmarshalBinary(env.Body); errUnmarshal != nil {
|
||||
log.Error("Failed to unmarshal trustlog body",
|
||||
"error", errUnmarshal,
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return fmt.Errorf("failed to unmarshal trustlog body: %w", errUnmarshal)
|
||||
}
|
||||
log.Debug("Trustlog unmarshaled successfully",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== Operation 序列化/反序列化 =====
|
||||
//
|
||||
|
||||
// MarshalOperation 序列化 Operation 为 Envelope 格式。
|
||||
func MarshalOperation(op *Operation, config EnvelopeConfig) ([]byte, error) {
|
||||
return MarshalTrustlog(op, config)
|
||||
}
|
||||
|
||||
// UnmarshalOperation 反序列化 Envelope 为 Operation。
|
||||
func UnmarshalOperation(data []byte) (*Operation, error) {
|
||||
var op Operation
|
||||
if err := UnmarshalTrustlog(data, &op); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &op, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== Record 序列化/反序列化 =====
|
||||
//
|
||||
|
||||
// MarshalRecord 序列化 Record 为 Envelope 格式。
|
||||
func MarshalRecord(record *Record, config EnvelopeConfig) ([]byte, error) {
|
||||
return MarshalTrustlog(record, config)
|
||||
}
|
||||
|
||||
// UnmarshalRecord 反序列化 Envelope 为 Record。
|
||||
func UnmarshalRecord(data []byte) (*Record, error) {
|
||||
var record Record
|
||||
if err := UnmarshalTrustlog(data, &record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 验证 =====
|
||||
//
|
||||
|
||||
// VerifyEnvelope 验证Envelope的完整性(使用EnvelopeConfig)。
|
||||
// 验证签名是否匹配,确保数据未被篡改。
|
||||
// 如果验证成功,返回解析后的Envelope结构体指针;如果验证失败,返回错误。
|
||||
func VerifyEnvelope(data []byte, config EnvelopeConfig) (*Envelope, error) {
|
||||
if config.Signer == nil {
|
||||
return nil, errors.New("signer is required for verification")
|
||||
}
|
||||
|
||||
verifyConfig := VerifyConfig(config)
|
||||
return VerifyEnvelopeWithConfig(data, verifyConfig)
|
||||
}
|
||||
|
||||
// VerifyEnvelopeWithConfig 验证Envelope的完整性(使用VerifyConfig)。
|
||||
// 验证签名是否匹配,确保数据未被篡改。
|
||||
// 如果验证成功,返回解析后的Envelope结构体指针;如果验证失败,返回错误。
|
||||
func VerifyEnvelopeWithConfig(data []byte, config VerifyConfig) (*Envelope, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying envelope",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
if config.Signer == nil {
|
||||
log.Error("Signer is nil")
|
||||
return nil, errors.New("signer is required for verification")
|
||||
}
|
||||
|
||||
env, err := UnmarshalEnvelope(data)
|
||||
if err != nil {
|
||||
log.Error("Failed to unmarshal envelope",
|
||||
"error", err,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to unmarshal envelope: %w", err)
|
||||
}
|
||||
log.Debug("Envelope unmarshaled for verification",
|
||||
"producerID", env.ProducerID,
|
||||
"bodyLength", len(env.Body),
|
||||
"signatureLength", len(env.Signature),
|
||||
)
|
||||
|
||||
// 验证签名
|
||||
log.Debug("Verifying signature",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
valid, err := config.Signer.Verify(env.Body, env.Signature)
|
||||
if err != nil {
|
||||
log.Error("Failed to verify signature",
|
||||
"error", err,
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to verify signature: %w", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
log.Warn("Signature verification failed",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return nil, errors.New("signature verification failed")
|
||||
}
|
||||
|
||||
log.Debug("Envelope verified successfully",
|
||||
"producerID", env.ProducerID,
|
||||
)
|
||||
return env, nil
|
||||
}
|
||||
215
api/model/envelope_debug_test.go
Normal file
215
api/model/envelope_debug_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestSignVerifyDataConsistency 详细测试加签和验签的数据一致性.
|
||||
func TestSignVerifyDataConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 序列化为DER格式
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名器
|
||||
signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||
|
||||
// 测试数据1
|
||||
testData1 := []byte("test data for signing")
|
||||
|
||||
// 测试数据2(不同数据)
|
||||
testData2 := []byte("different test data")
|
||||
|
||||
// 1. 对testData1签名
|
||||
signature1, err := signer.Sign(testData1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signature1)
|
||||
|
||||
// 2. 用testData1验证signature1 - 应该成功
|
||||
valid, err := signer.Verify(testData1, signature1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||
|
||||
// 3. 用testData2验证signature1 - 应该失败
|
||||
valid, err = signer.Verify(testData2, signature1)
|
||||
require.Error(t, err, "使用不同数据验证应该失败")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
assert.False(t, valid)
|
||||
|
||||
// 4. 对testData2签名
|
||||
signature2, err := signer.Sign(testData2)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signature2)
|
||||
|
||||
// 5. 用testData2验证signature2 - 应该成功
|
||||
valid, err = signer.Verify(testData2, signature2)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||
|
||||
// 6. 用testData1验证signature2 - 应该失败
|
||||
valid, err = signer.Verify(testData1, signature2)
|
||||
require.Error(t, err, "使用不同数据验证应该失败")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
assert.False(t, valid)
|
||||
|
||||
t.Logf("测试完成:签名和验证逻辑正常")
|
||||
}
|
||||
|
||||
// TestEnvelopeBodyTampering 测试修改envelope body后验签应该失败.
|
||||
func TestEnvelopeBodyTampering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 序列化为DER格式
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名配置
|
||||
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||
|
||||
// 创建测试Operation
|
||||
op := &model.Operation{
|
||||
OpID: "op-test-002",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/456",
|
||||
ProducerID: "producer-2",
|
||||
OpActor: "actor-2",
|
||||
}
|
||||
|
||||
err = op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. 加签:序列化为Envelope
|
||||
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, envelopeData)
|
||||
|
||||
// 2. 验签:验证原始Envelope - 应该成功
|
||||
verifiedEnv, err := model.VerifyEnvelopeWithConfig(envelopeData, verifyConfig)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, verifiedEnv)
|
||||
|
||||
// 3. 反序列化获取原始body
|
||||
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||
require.NoError(t, err)
|
||||
originalBody := originalEnv.Body
|
||||
originalSignature := originalEnv.Signature
|
||||
|
||||
t.Logf("原始body长度: %d", len(originalBody))
|
||||
t.Logf("原始签名长度: %d", len(originalSignature))
|
||||
|
||||
// 4. 创建修改后的body(完全不同的数据)
|
||||
modifiedBody := []byte("completely different body content")
|
||||
require.NotEqual(t, originalBody, modifiedBody, "修改后的body应该不同")
|
||||
|
||||
// 5. 创建修改后的envelope(使用原始签名但修改body)
|
||||
modifiedEnv := &model.Envelope{
|
||||
ProducerID: originalEnv.ProducerID,
|
||||
Signature: originalSignature, // 使用原始签名
|
||||
Body: modifiedBody, // 使用修改后的body
|
||||
}
|
||||
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 6. 验签修改后的envelope - 应该失败
|
||||
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||
require.Error(t, err, "修改body后验签应该失败")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
|
||||
t.Logf("测试完成:修改body后验签正确失败")
|
||||
}
|
||||
|
||||
// TestEnvelopeSignatureTampering 测试修改envelope signature后验签应该失败.
|
||||
func TestEnvelopeSignatureTampering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 序列化为DER格式
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名配置
|
||||
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||
|
||||
// 创建测试Operation
|
||||
op := &model.Operation{
|
||||
OpID: "op-test-003",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/789",
|
||||
ProducerID: "producer-3",
|
||||
OpActor: "actor-3",
|
||||
}
|
||||
|
||||
err = op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. 加签:序列化为Envelope
|
||||
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. 反序列化获取原始signature
|
||||
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||
require.NoError(t, err)
|
||||
originalSignature := originalEnv.Signature
|
||||
|
||||
// 3. 创建修改后的signature(完全不同的数据)
|
||||
modifiedSignature := make([]byte, len(originalSignature))
|
||||
copy(modifiedSignature, originalSignature)
|
||||
// 修改最后一个字节
|
||||
if len(modifiedSignature) > 0 {
|
||||
modifiedSignature[len(modifiedSignature)-1] ^= 0xFF
|
||||
}
|
||||
require.NotEqual(t, originalSignature, modifiedSignature, "修改后的signature应该不同")
|
||||
|
||||
// 4. 创建修改后的envelope(使用原始body但修改signature)
|
||||
modifiedEnv := &model.Envelope{
|
||||
ProducerID: originalEnv.ProducerID,
|
||||
Signature: modifiedSignature, // 使用修改后的signature
|
||||
Body: originalEnv.Body, // 使用原始body
|
||||
}
|
||||
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 5. 验签修改后的envelope - 应该失败
|
||||
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||
require.Error(t, err, "修改signature后验签应该失败")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
|
||||
t.Logf("测试完成:修改signature后验签正确失败")
|
||||
}
|
||||
126
api/model/envelope_sign_verify_test.go
Normal file
126
api/model/envelope_sign_verify_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestSignVerifyConsistency 测试加签和验签的一致性
|
||||
// 验证加签时使用的数据和验签时使用的数据是否一致.
|
||||
func TestSignVerifyConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 序列化为DER格式
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名配置
|
||||
signConfig := model.NewSM2EnvelopeConfig(privateKeyDER, publicKeyDER)
|
||||
verifyConfig := model.NewSM2VerifyConfig(publicKeyDER)
|
||||
|
||||
// 创建测试Operation
|
||||
op := &model.Operation{
|
||||
OpID: "op-test-001",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err = op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. 加签:序列化为Envelope
|
||||
envelopeData, err := model.MarshalOperation(op, signConfig)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, envelopeData)
|
||||
|
||||
// 2. 验签:验证Envelope
|
||||
verifiedEnv, err := model.VerifyEnvelopeWithConfig(envelopeData, verifyConfig)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, verifiedEnv)
|
||||
|
||||
// 3. 验证:加签时使用的body和验签时使用的body应该一致
|
||||
// 手动反序列化envelope以获取body
|
||||
originalEnv, err := model.UnmarshalEnvelope(envelopeData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证body一致
|
||||
assert.Equal(t, originalEnv.Body, verifiedEnv.Body, "加签和验签时使用的body应该完全一致")
|
||||
assert.Equal(t, originalEnv.ProducerID, verifiedEnv.ProducerID)
|
||||
assert.Equal(t, originalEnv.Signature, verifiedEnv.Signature)
|
||||
|
||||
// 4. 验证:如果修改body,验签应该失败
|
||||
// 创建完全不同的body内容
|
||||
modifiedBody := []byte("completely different body content")
|
||||
require.NotEqual(t, originalEnv.Body, modifiedBody, "修改后的body应该不同")
|
||||
|
||||
modifiedEnv := &model.Envelope{
|
||||
ProducerID: originalEnv.ProducerID,
|
||||
Signature: originalEnv.Signature, // 使用旧的签名
|
||||
Body: modifiedBody, // 使用修改后的body
|
||||
}
|
||||
modifiedData, err := model.MarshalEnvelope(modifiedEnv)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验签应该失败,因为body被修改了但签名还是旧的
|
||||
_, err = model.VerifyEnvelopeWithConfig(modifiedData, verifyConfig)
|
||||
require.Error(t, err, "修改body后验签应该失败")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
}
|
||||
|
||||
// TestSignVerifyDirectData 直接测试对相同数据的签名和验证.
|
||||
func TestSignVerifyDirectData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 序列化为DER格式
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建签名器
|
||||
signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||
|
||||
// 测试数据
|
||||
testData := []byte("test data for signing")
|
||||
|
||||
// 1. 签名
|
||||
signature, err := signer.Sign(testData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signature)
|
||||
|
||||
// 2. 验证(使用相同的数据)
|
||||
valid, err := signer.Verify(testData, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid, "使用相同数据验证应该成功")
|
||||
|
||||
// 3. 验证(使用不同的数据)
|
||||
modifiedData := []byte("modified test data")
|
||||
valid, err = signer.Verify(modifiedData, signature)
|
||||
// VerifySignature在验证失败时会返回错误,这是预期的
|
||||
require.Error(t, err, "使用不同数据验证应该失败并返回错误")
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
assert.False(t, valid)
|
||||
}
|
||||
423
api/model/envelope_test.go
Normal file
423
api/model/envelope_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestNewEnvelopeConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
config := model.NewEnvelopeConfig(signer)
|
||||
assert.NotNil(t, config.Signer)
|
||||
}
|
||||
|
||||
func TestNewSM2EnvelopeConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privateKey := []byte("test-private-key")
|
||||
publicKey := []byte("test-public-key")
|
||||
|
||||
config := model.NewSM2EnvelopeConfig(privateKey, publicKey)
|
||||
assert.NotNil(t, config.Signer)
|
||||
}
|
||||
|
||||
func TestNewVerifyConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
config := model.NewVerifyConfig(signer)
|
||||
assert.NotNil(t, config.Signer)
|
||||
}
|
||||
|
||||
func TestNewSM2VerifyConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
|
||||
config := model.NewSM2VerifyConfig(publicKey)
|
||||
assert.NotNil(t, config.Signer)
|
||||
}
|
||||
|
||||
func TestMarshalEnvelope_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.MarshalEnvelope(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "envelope cannot be nil")
|
||||
}
|
||||
|
||||
func TestMarshalEnvelope_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
assert.NotEmpty(t, data)
|
||||
}
|
||||
|
||||
func TestMarshalEnvelope_EmptyFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "",
|
||||
Signature: []byte{},
|
||||
Body: []byte{},
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelope_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.UnmarshalEnvelope(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelope_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.UnmarshalEnvelope([]byte{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalEnvelope_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := model.MarshalEnvelope(original)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// Unmarshal
|
||||
result, err := model.UnmarshalEnvelope(data)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Verify
|
||||
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||
assert.Equal(t, original.Signature, result.Signature)
|
||||
assert.Equal(t, original.Body, result.Body)
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelopeProducerID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
|
||||
producerID, err := model.UnmarshalEnvelopeProducerID(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "producer-1", producerID)
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelopeSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
|
||||
producerID, signature, err := model.UnmarshalEnvelopeSignature(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "producer-1", producerID)
|
||||
assert.Equal(t, []byte("signature"), signature)
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelopeSignature_EmptyData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, err := model.UnmarshalEnvelopeSignature(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelopeSignature_InvalidData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, err := model.UnmarshalEnvelopeSignature([]byte{0xff, 0xff})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnmarshalEnvelopeProducerID_EmptyData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.UnmarshalEnvelopeProducerID(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestMarshalTrustlog_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.MarshalTrustlog(nil, model.EnvelopeConfig{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "trustlog cannot be nil")
|
||||
}
|
||||
|
||||
func TestMarshalTrustlog_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalTrustlog(op, config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestUnmarshalTrustlog_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{}
|
||||
err := model.UnmarshalTrustlog(nil, op)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestMarshalOperation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalOperation(op, config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestUnmarshalOperation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalOperation(op, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := model.UnmarshalOperation(data)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, op.OpID, result.OpID)
|
||||
}
|
||||
|
||||
func TestMarshalRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := rec.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalRecord(rec, config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestUnmarshalRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := rec.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalRecord(rec, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := model.UnmarshalRecord(data)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, rec.ID, result.ID)
|
||||
}
|
||||
|
||||
func TestVerifyEnvelope_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
env, err := model.VerifyEnvelope(nil, config)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, env)
|
||||
assert.Contains(t, err.Error(), "data is empty")
|
||||
}
|
||||
|
||||
func TestVerifyEnvelope_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
verifiedEnv, err := model.VerifyEnvelope(data, config)
|
||||
// NopSigner verifies by comparing body with signature
|
||||
// Since signature != body, verification should fail
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, verifiedEnv)
|
||||
}
|
||||
|
||||
func TestVerifyEnvelopeWithConfig_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := model.NewVerifyConfig(model.NewNopSigner())
|
||||
env, err := model.VerifyEnvelopeWithConfig(nil, config)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, env)
|
||||
// Error message may vary, just check that it's an error
|
||||
assert.NotEmpty(t, err.Error())
|
||||
}
|
||||
|
||||
func TestVerifyEnvelopeWithConfig_NilSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.VerifyConfig{Signer: nil}
|
||||
verifiedEnv, err := model.VerifyEnvelopeWithConfig(data, config)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, verifiedEnv)
|
||||
assert.Contains(t, err.Error(), "signer is required")
|
||||
}
|
||||
|
||||
func TestVerifyEnvelopeWithConfig_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create envelope with matching body and signature (NopSigner requirement)
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("body"), // Same as body for NopSigner
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.NewVerifyConfig(model.NewNopSigner())
|
||||
verifiedEnv, err := model.VerifyEnvelopeWithConfig(data, config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, verifiedEnv)
|
||||
assert.Equal(t, env.ProducerID, verifiedEnv.ProducerID)
|
||||
assert.Equal(t, env.Signature, verifiedEnv.Signature)
|
||||
assert.Equal(t, env.Body, verifiedEnv.Body)
|
||||
}
|
||||
|
||||
func TestVerifyEnvelope_NilSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
env := &model.Envelope{
|
||||
ProducerID: "producer-1",
|
||||
Signature: []byte("signature"),
|
||||
Body: []byte("body"),
|
||||
}
|
||||
|
||||
data, err := model.MarshalEnvelope(env)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.EnvelopeConfig{Signer: nil}
|
||||
verifiedEnv, err := model.VerifyEnvelope(data, config)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, verifiedEnv)
|
||||
assert.Contains(t, err.Error(), "signer is required")
|
||||
}
|
||||
267
api/model/hash.go
Normal file
267
api/model/hash.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
stdsha256 "crypto/sha256"
|
||||
stdsha512 "crypto/sha512"
|
||||
"encoding/hex"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
miniosha256 "github.com/minio/sha256-simd"
|
||||
"github.com/zeebo/blake3"
|
||||
"golang.org/x/crypto/blake2b"
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/md4" //nolint:staticcheck // 保留弱加密算法以支持遗留系统兼容性
|
||||
"golang.org/x/crypto/ripemd160" //nolint:staticcheck // 保留弱加密算法以支持遗留系统兼容性
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
// HashType 定义支持的哈希算法类型.
|
||||
type HashType string
|
||||
|
||||
const (
|
||||
MD5 HashType = "md5"
|
||||
SHA1 HashType = "sha1"
|
||||
SHA224 HashType = "sha224"
|
||||
SHA256 HashType = "sha256"
|
||||
SHA384 HashType = "sha384"
|
||||
SHA512 HashType = "sha512"
|
||||
Sha512224 HashType = "sha512_224"
|
||||
Sha512256 HashType = "sha512_256"
|
||||
|
||||
Sha256Simd HashType = "sha256-simd"
|
||||
BLAKE3 HashType = "blake3"
|
||||
BLAKE2B HashType = "blake2b"
|
||||
BLAKE2S HashType = "blake2s"
|
||||
MD4 HashType = "md4"
|
||||
RIPEMD160 HashType = "ripemd160"
|
||||
Sha3224 HashType = "sha3-224"
|
||||
Sha3256 HashType = "sha3-256"
|
||||
Sha3384 HashType = "sha3-384"
|
||||
Sha3512 HashType = "sha3-512"
|
||||
)
|
||||
|
||||
// 使用 map 来存储支持的算法,提高查找效率.
|
||||
//
|
||||
//nolint:gochecknoglobals // 全局缓存用于算法查找和实例复用.
|
||||
var (
|
||||
supportedAlgorithms []string
|
||||
supportedAlgorithmsMap map[string]bool
|
||||
supportedAlgorithmsOnce sync.Once
|
||||
|
||||
// 享元模式:存储已创建的 HashTool 实例.
|
||||
toolPool = make(map[HashType]*HashTool)
|
||||
poolMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// HashTool 哈希工具类.
|
||||
type HashTool struct {
|
||||
hashType HashType
|
||||
}
|
||||
|
||||
// GetHashTool 获取指定类型的 HashTool.
|
||||
func GetHashTool(hashType HashType) *HashTool {
|
||||
poolMutex.RLock()
|
||||
if tool, exists := toolPool[hashType]; exists {
|
||||
poolMutex.RUnlock()
|
||||
return tool
|
||||
}
|
||||
poolMutex.RUnlock()
|
||||
|
||||
poolMutex.Lock()
|
||||
defer poolMutex.Unlock()
|
||||
|
||||
if tool, exists := toolPool[hashType]; exists {
|
||||
return tool
|
||||
}
|
||||
|
||||
tool := &HashTool{hashType: hashType}
|
||||
toolPool[hashType] = tool
|
||||
return tool
|
||||
}
|
||||
|
||||
// NewHashTool 创建新的哈希工具实例.
|
||||
func NewHashTool(hashType HashType) *HashTool {
|
||||
return &HashTool{hashType: hashType}
|
||||
}
|
||||
|
||||
// getHasher 根据哈希类型获取对应的哈希器.
|
||||
func (h *HashTool) getHasher() hash.Hash {
|
||||
switch h.hashType {
|
||||
case MD5:
|
||||
return md5.New()
|
||||
case SHA1:
|
||||
return sha1.New()
|
||||
case SHA224:
|
||||
return stdsha256.New224()
|
||||
case SHA256:
|
||||
return stdsha256.New()
|
||||
case SHA384:
|
||||
return stdsha512.New384()
|
||||
case SHA512:
|
||||
return stdsha512.New()
|
||||
case Sha512224:
|
||||
return stdsha512.New512_224()
|
||||
case Sha512256:
|
||||
return stdsha512.New512_256()
|
||||
|
||||
// 第三方算法
|
||||
case Sha256Simd:
|
||||
return miniosha256.New()
|
||||
case BLAKE3:
|
||||
return blake3.New()
|
||||
case BLAKE2B:
|
||||
hasher, _ := blake2b.New512(nil)
|
||||
return hasher
|
||||
case BLAKE2S:
|
||||
hasher, _ := blake2s.New256(nil)
|
||||
return hasher
|
||||
case MD4:
|
||||
return md4.New()
|
||||
case RIPEMD160:
|
||||
return ripemd160.New()
|
||||
case Sha3224:
|
||||
return sha3.New224()
|
||||
case Sha3256:
|
||||
return sha3.New256()
|
||||
case Sha3384:
|
||||
return sha3.New384()
|
||||
case Sha3512:
|
||||
return sha3.New512()
|
||||
|
||||
default:
|
||||
return stdsha256.New() // 默认使用 SHA256
|
||||
}
|
||||
}
|
||||
|
||||
// hashData 通用的哈希计算函数.
|
||||
func (h *HashTool) hashData(processFunc func(hasher hash.Hash) error) (string, error) {
|
||||
hasher := h.getHasher()
|
||||
if err := processFunc(hasher); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// HashString 对字符串进行哈希计算.
|
||||
func (h *HashTool) HashString(data string) (string, error) {
|
||||
return h.hashData(func(hasher hash.Hash) error {
|
||||
_, err := hasher.Write([]byte(data))
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// HashBytes 对字节数组进行哈希计算.
|
||||
func (h *HashTool) HashBytes(data []byte) (string, error) {
|
||||
return h.hashData(func(hasher hash.Hash) error {
|
||||
_, err := hasher.Write(data)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// HashBytesRaw 对字节数组进行哈希计算,返回原始字节数组(非hex字符串).
|
||||
func (h *HashTool) HashBytesRaw(data []byte) ([]byte, error) {
|
||||
hasher := h.getHasher()
|
||||
if _, err := hasher.Write(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hasher.Sum(nil), nil
|
||||
}
|
||||
|
||||
// HashFile 对文件进行哈希计算.
|
||||
func (h *HashTool) HashFile(_ context.Context, filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return h.hashData(func(hasher hash.Hash) error {
|
||||
_, copyErr := io.Copy(hasher, file)
|
||||
return copyErr
|
||||
})
|
||||
}
|
||||
|
||||
// HashStream 对流数据进行哈希计算.
|
||||
func (h *HashTool) HashStream(reader io.Reader) (string, error) {
|
||||
return h.hashData(func(hasher hash.Hash) error {
|
||||
_, err := io.Copy(hasher, reader)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// initSupportedAlgorithms 初始化支持的算法数据.
|
||||
func initSupportedAlgorithms() {
|
||||
algorithms := []HashType{
|
||||
MD5, SHA1, SHA224, SHA256, SHA384, SHA512,
|
||||
Sha512224, Sha512256, Sha256Simd, BLAKE3,
|
||||
BLAKE2B, BLAKE2S, MD4, RIPEMD160,
|
||||
Sha3224, Sha3256, Sha3384, Sha3512,
|
||||
}
|
||||
|
||||
supportedAlgorithms = make([]string, len(algorithms))
|
||||
supportedAlgorithmsMap = make(map[string]bool, len(algorithms))
|
||||
|
||||
for i, alg := range algorithms {
|
||||
algStr := string(alg)
|
||||
supportedAlgorithms[i] = algStr
|
||||
supportedAlgorithmsMap[strings.ToLower(algStr)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetSupportedAlgorithms 获取支持的哈希算法列表.
|
||||
func GetSupportedAlgorithms() []string {
|
||||
supportedAlgorithmsOnce.Do(initSupportedAlgorithms)
|
||||
return supportedAlgorithms
|
||||
}
|
||||
|
||||
// IsAlgorithmSupported 检查算法是否支持 - 使用 map 提高性能.
|
||||
func IsAlgorithmSupported(algorithm string) bool {
|
||||
supportedAlgorithmsOnce.Do(initSupportedAlgorithms)
|
||||
return supportedAlgorithmsMap[strings.ToLower(algorithm)]
|
||||
}
|
||||
|
||||
// CompareHash 比较哈希值.
|
||||
func (h *HashTool) CompareHash(data, expectedHash string) (bool, error) {
|
||||
actualHash, err := h.HashString(data)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return strings.EqualFold(actualHash, expectedHash), nil
|
||||
}
|
||||
|
||||
// CompareFileHash 比较文件哈希值.
|
||||
func (h *HashTool) CompareFileHash(ctx context.Context, filePath, expectedHash string) (bool, error) {
|
||||
actualHash, err := h.HashFile(ctx, filePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return strings.EqualFold(actualHash, expectedHash), nil
|
||||
}
|
||||
|
||||
// GetHashType 获取当前工具使用的哈希类型.
|
||||
func (h *HashTool) GetHashType() HashType {
|
||||
return h.hashType
|
||||
}
|
||||
|
||||
type HashData interface {
|
||||
Key() string
|
||||
Hash() string
|
||||
Type() HashType
|
||||
}
|
||||
|
||||
type Hashable interface {
|
||||
DoHash(ctx context.Context) (HashData, error)
|
||||
}
|
||||
|
||||
type HashList []HashData
|
||||
|
||||
func (h HashList) GetHashType() HashType {
|
||||
return h[0].Type()
|
||||
}
|
||||
545
api/model/hash_test.go
Normal file
545
api/model/hash_test.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestGetHashTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hashType model.HashType
|
||||
}{
|
||||
{
|
||||
name: "SHA256",
|
||||
hashType: model.SHA256,
|
||||
},
|
||||
{
|
||||
name: "SHA256Simd",
|
||||
hashType: model.Sha256Simd,
|
||||
},
|
||||
{
|
||||
name: "MD5",
|
||||
hashType: model.MD5,
|
||||
},
|
||||
{
|
||||
name: "SHA1",
|
||||
hashType: model.SHA1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool := model.GetHashTool(tt.hashType)
|
||||
assert.NotNil(t, tool)
|
||||
// Verify it works
|
||||
_, err := tool.HashString("test")
|
||||
require.NoError(t, err)
|
||||
// Verify hash type
|
||||
assert.Equal(t, tt.hashType, tool.GetHashType())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHashTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
assert.NotNil(t, tool)
|
||||
// Verify it works
|
||||
_, err := tool.HashString("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHashTool_HashString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hashType model.HashType
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "SHA256",
|
||||
hashType: model.SHA256,
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "SHA256Simd",
|
||||
hashType: model.Sha256Simd,
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MD5",
|
||||
hashType: model.MD5,
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "SHA1",
|
||||
hashType: model.SHA1,
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "SHA512",
|
||||
hashType: model.SHA512,
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
hashType: model.SHA256,
|
||||
input: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool := model.NewHashTool(tt.hashType)
|
||||
result, err := tool.HashString(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_HashBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hashType model.HashType
|
||||
input []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "SHA256",
|
||||
hashType: model.SHA256,
|
||||
input: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "SHA256Simd",
|
||||
hashType: model.Sha256Simd,
|
||||
input: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty bytes",
|
||||
hashType: model.SHA256,
|
||||
input: []byte{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "large input",
|
||||
hashType: model.SHA256,
|
||||
input: make([]byte, 1000),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool := model.NewHashTool(tt.hashType)
|
||||
result, err := tool.HashBytes(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_Deterministic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
input := "test string"
|
||||
|
||||
result1, err1 := tool.HashString(input)
|
||||
require.NoError(t, err1)
|
||||
|
||||
result2, err2 := tool.HashString(input)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Same input should produce same hash
|
||||
assert.Equal(t, result1, result2)
|
||||
}
|
||||
|
||||
func TestHashTool_DifferentInputs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
|
||||
result1, err1 := tool.HashString("input1")
|
||||
require.NoError(t, err1)
|
||||
|
||||
result2, err2 := tool.HashString("input2")
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Different inputs should produce different hashes
|
||||
assert.NotEqual(t, result1, result2)
|
||||
}
|
||||
|
||||
func TestHashTool_StringVsBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
input := "test"
|
||||
|
||||
stringHash, err1 := tool.HashString(input)
|
||||
require.NoError(t, err1)
|
||||
|
||||
bytesHash, err2 := tool.HashBytes([]byte(input))
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Same data in different formats should produce same hash
|
||||
assert.Equal(t, stringHash, bytesHash)
|
||||
}
|
||||
|
||||
func TestHashTool_MultipleTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "test"
|
||||
hashTypes := []model.HashType{
|
||||
model.MD5,
|
||||
model.SHA1,
|
||||
model.SHA256,
|
||||
model.SHA512,
|
||||
model.Sha256Simd,
|
||||
}
|
||||
|
||||
results := make(map[model.HashType]string)
|
||||
for _, hashType := range hashTypes {
|
||||
tool := model.NewHashTool(hashType)
|
||||
result, err := tool.HashString(input)
|
||||
require.NoError(t, err)
|
||||
results[hashType] = result
|
||||
}
|
||||
|
||||
// All should produce different hashes (except possibly some edge cases)
|
||||
// At minimum, verify they all produced valid hashes
|
||||
for hashType, result := range results {
|
||||
assert.NotEmpty(t, result, "HashType: %v", hashType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_GetHashTool_Caching(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hashType := model.SHA256
|
||||
tool1 := model.GetHashTool(hashType)
|
||||
tool2 := model.GetHashTool(hashType)
|
||||
|
||||
// Should return the same instance (cached)
|
||||
assert.Equal(t, tool1, tool2)
|
||||
}
|
||||
|
||||
func TestHashTool_HashFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temporary file
|
||||
tmpFile := t.TempDir() + "/test.txt"
|
||||
err := os.WriteFile(tmpFile, []byte("test content"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
ctx := context.Background()
|
||||
result, err := tool.HashFile(ctx, tmpFile)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
|
||||
func TestHashTool_HashFile_NotExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
ctx := context.Background()
|
||||
_, err := tool.HashFile(ctx, "/nonexistent/file")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHashTool_HashStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
reader := bytes.NewReader([]byte("test content"))
|
||||
|
||||
result, err := tool.HashStream(reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
|
||||
func TestHashTool_HashStream_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
reader := bytes.NewReader([]byte{})
|
||||
|
||||
result, err := tool.HashStream(reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result) // Even empty input produces a hash
|
||||
}
|
||||
|
||||
func TestGetSupportedAlgorithms(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
algorithms := model.GetSupportedAlgorithms()
|
||||
assert.NotEmpty(t, algorithms)
|
||||
assert.Contains(t, algorithms, string(model.SHA256))
|
||||
assert.Contains(t, algorithms, string(model.Sha256Simd))
|
||||
// Verify case-insensitive check
|
||||
assert.True(t, model.IsAlgorithmSupported("SHA256"))
|
||||
assert.True(t, model.IsAlgorithmSupported("sha256"))
|
||||
assert.True(t, model.IsAlgorithmSupported("sha256-simd"))
|
||||
}
|
||||
|
||||
func TestIsAlgorithmSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
algorithm string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "SHA256",
|
||||
algorithm: "SHA256",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SHA256 lowercase",
|
||||
algorithm: "sha256",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Sha256Simd",
|
||||
algorithm: "sha256-simd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Sha256Simd mixed case",
|
||||
algorithm: "Sha256-Simd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
algorithm: "UNSUPPORTED",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := model.IsAlgorithmSupported(tt.algorithm)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_GetHashType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA512)
|
||||
assert.Equal(t, model.SHA512, tool.GetHashType())
|
||||
}
|
||||
|
||||
func TestHashTool_AllHashTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hashTypes := []model.HashType{
|
||||
model.MD5,
|
||||
model.SHA1,
|
||||
model.SHA224,
|
||||
model.SHA256,
|
||||
model.SHA384,
|
||||
model.SHA512,
|
||||
model.Sha256Simd,
|
||||
model.BLAKE3,
|
||||
}
|
||||
|
||||
for _, hashType := range hashTypes {
|
||||
tool := model.NewHashTool(hashType)
|
||||
result, err := tool.HashString("test")
|
||||
require.NoError(t, err, "HashType: %v", hashType)
|
||||
assert.NotEmpty(t, result, "HashType: %v", hashType)
|
||||
assert.Equal(t, hashType, tool.GetHashType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_CompareHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
data := "test data"
|
||||
|
||||
// Generate hash
|
||||
hash, err := tool.HashString(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
expectedHash string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "匹配的哈希值",
|
||||
data: data,
|
||||
expectedHash: hash,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "大小写不同但内容相同",
|
||||
data: data,
|
||||
expectedHash: strings.ToUpper(hash),
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "不匹配的哈希值",
|
||||
data: data,
|
||||
expectedHash: "invalid_hash",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "不同的数据",
|
||||
data: "different data",
|
||||
expectedHash: hash,
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
match, err := tool.CompareHash(tt.data, tt.expectedHash)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.shouldMatch, match)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashTool_CompareFileHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temporary file
|
||||
tmpFile := t.TempDir() + "/test.txt"
|
||||
content := []byte("test file content")
|
||||
err := os.WriteFile(tmpFile, content, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := model.NewHashTool(model.SHA256)
|
||||
ctx := context.Background()
|
||||
|
||||
// Generate expected hash
|
||||
expectedHash, err := tool.HashFile(ctx, tmpFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filePath string
|
||||
expectedHash string
|
||||
shouldMatch bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "匹配的文件哈希",
|
||||
filePath: tmpFile,
|
||||
expectedHash: expectedHash,
|
||||
shouldMatch: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "大小写不同但内容相同",
|
||||
filePath: tmpFile,
|
||||
expectedHash: strings.ToUpper(expectedHash),
|
||||
shouldMatch: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "不匹配的文件哈希",
|
||||
filePath: tmpFile,
|
||||
expectedHash: "invalid_hash",
|
||||
shouldMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "文件不存在",
|
||||
filePath: "/nonexistent/file",
|
||||
expectedHash: expectedHash,
|
||||
shouldMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
match, err := tool.CompareFileHash(ctx, tt.filePath, tt.expectedHash)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.shouldMatch, match)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashList_GetHashType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock hash data
|
||||
mockHash := &mockHashData{
|
||||
key: "test-key",
|
||||
hash: "test-hash",
|
||||
hashType: model.SHA256,
|
||||
}
|
||||
|
||||
hashList := model.HashList{mockHash}
|
||||
assert.Equal(t, model.SHA256, hashList.GetHashType())
|
||||
}
|
||||
|
||||
// mockHashData implements HashData interface for testing.
|
||||
type mockHashData struct {
|
||||
key string
|
||||
hash string
|
||||
hashType model.HashType
|
||||
}
|
||||
|
||||
func (m *mockHashData) Key() string {
|
||||
return m.key
|
||||
}
|
||||
|
||||
func (m *mockHashData) Hash() string {
|
||||
return m.hash
|
||||
}
|
||||
|
||||
func (m *mockHashData) Type() model.HashType {
|
||||
return m.hashType
|
||||
}
|
||||
577
api/model/operation.go
Normal file
577
api/model/operation.go
Normal file
@@ -0,0 +1,577 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
//
|
||||
// ===== 操作来源类型 =====
|
||||
//
|
||||
|
||||
// Source 表示操作来源,用于区分不同系统模块(IRP、DOIP)。
|
||||
type Source string
|
||||
|
||||
const (
|
||||
OpSourceIRP Source = "IRP"
|
||||
OpSourceDOIP Source = "DOIP"
|
||||
)
|
||||
|
||||
//
|
||||
// ===== 操作类型枚举 =====
|
||||
//
|
||||
|
||||
// Type 表示操作的具体类型。
|
||||
type Type string
|
||||
|
||||
// DOIP 操作类型枚举。
|
||||
const (
|
||||
OpTypeHello Type = "Hello"
|
||||
OpTypeRetrieve Type = "Retrieve"
|
||||
OpTypeCreate Type = "Create"
|
||||
OpTypeDelete Type = "Delete"
|
||||
OpTypeUpdate Type = "Update"
|
||||
OpTypeSearch Type = "Search"
|
||||
OpTypeListOperations Type = "ListOperations"
|
||||
)
|
||||
|
||||
// IRP 操作类型枚举。
|
||||
const (
|
||||
OpTypeOCReserved Type = "OC_RESERVED"
|
||||
OpTypeOCResolution Type = "OC_RESOLUTION"
|
||||
OpTypeOCGetSiteInfo Type = "OC_GET_SITEINFO"
|
||||
OpTypeOCCreateHandle Type = "OC_CREATE_HANDLE"
|
||||
OpTypeOCDeleteHandle Type = "OC_DELETE_HANDLE"
|
||||
OpTypeOCAddValue Type = "OC_ADD_VALUE"
|
||||
OpTypeOCRemoveValue Type = "OC_REMOVE_VALUE"
|
||||
OpTypeOCModifyValue Type = "OC_MODIFY_VALUE"
|
||||
OpTypeOCListHandle Type = "OC_LIST_HANDLE"
|
||||
OpTypeOCListNA Type = "OC_LIST_NA"
|
||||
OpTypeOCResolutionDOID Type = "OC_RESOLUTION_DOID"
|
||||
OpTypeOCCreateDOID Type = "OC_CREATE_DOID"
|
||||
OpTypeOCDeleteDOID Type = "OC_DELETE_DOID"
|
||||
OpTypeOCUpdateDOID Type = "OC_UPDATE_DOID"
|
||||
OpTypeOCBatchCreateDOID Type = "OC_BATCH_CREATE_DOID"
|
||||
OpTypeOCResolutionDOIDRecursive Type = "OC_RESOLUTION_DOID_RECURSIVE"
|
||||
OpTypeOCGetUsers Type = "OC_GET_USERS"
|
||||
OpTypeOCGetRepos Type = "OC_GET_REPOS"
|
||||
OpTypeOCVerifyIRS Type = "OC_VERIFY_IRS"
|
||||
OpTypeOCResolveGRS Type = "OC_RESOLVE_GRS"
|
||||
OpTypeOCCreateOrgGRS Type = "OC_CREATE_ORG_GRS"
|
||||
OpTypeOCUpdateOrgGRS Type = "OC_UPDATE_ORG_GRS"
|
||||
OpTypeOCDeleteOrgGRS Type = "OC_DELETE_ORG_GRS"
|
||||
OpTypeOCSyncOrgIRSParent Type = "OC_SYNC_ORG_IRS_PARENT"
|
||||
OpTypeOCUpdateOrgIRSParent Type = "OC_UPDATE_ORG_IRS_PARENT"
|
||||
OpTypeOCDeleteOrgIRSParent Type = "OC_DELETE_ORG_IRS_PARENT"
|
||||
OpTypeOCChallengeResponse Type = "OC_CHALLENGE_RESPONSE"
|
||||
OpTypeOCVerifyChallenge Type = "OC_VERIFY_CHALLENGE"
|
||||
OpTypeOCSessionSetup Type = "OC_SESSION_SETUP"
|
||||
OpTypeOCSessionTerminate Type = "OC_SESSION_TERMINATE"
|
||||
OpTypeOCSessionExchangeKey Type = "OC_SESSION_EXCHANGEKEY"
|
||||
OpTypeOCVerifyRouter Type = "OC_VERIFY_ROUTER"
|
||||
OpTypeOCQueryRouter Type = "OC_QUERY_ROUTER"
|
||||
)
|
||||
|
||||
//
|
||||
// ===== 操作类型检索工具 =====
|
||||
//
|
||||
|
||||
// allOpTypes 存储不同来源的操作类型列表,用于快速查找和验证。
|
||||
//
|
||||
//nolint:gochecknoglobals // 全局常量映射用于操作类型查找
|
||||
var allOpTypes = map[Source][]Type{
|
||||
OpSourceDOIP: {
|
||||
OpTypeHello, OpTypeRetrieve, OpTypeCreate,
|
||||
OpTypeDelete, OpTypeUpdate, OpTypeSearch,
|
||||
OpTypeListOperations,
|
||||
},
|
||||
OpSourceIRP: {
|
||||
OpTypeOCReserved, OpTypeOCResolution, OpTypeOCGetSiteInfo,
|
||||
OpTypeOCCreateHandle, OpTypeOCDeleteHandle, OpTypeOCAddValue,
|
||||
OpTypeOCRemoveValue, OpTypeOCModifyValue, OpTypeOCListHandle,
|
||||
OpTypeOCListNA, OpTypeOCResolutionDOID, OpTypeOCCreateDOID,
|
||||
OpTypeOCDeleteDOID, OpTypeOCUpdateDOID, OpTypeOCBatchCreateDOID,
|
||||
OpTypeOCResolutionDOIDRecursive, OpTypeOCGetUsers, OpTypeOCGetRepos,
|
||||
OpTypeOCVerifyIRS, OpTypeOCResolveGRS, OpTypeOCCreateOrgGRS,
|
||||
OpTypeOCUpdateOrgGRS, OpTypeOCDeleteOrgGRS, OpTypeOCSyncOrgIRSParent,
|
||||
OpTypeOCUpdateOrgIRSParent, OpTypeOCDeleteOrgIRSParent,
|
||||
OpTypeOCChallengeResponse, OpTypeOCVerifyChallenge,
|
||||
OpTypeOCSessionSetup, OpTypeOCSessionTerminate,
|
||||
OpTypeOCSessionExchangeKey, OpTypeOCVerifyRouter, OpTypeOCQueryRouter,
|
||||
},
|
||||
}
|
||||
|
||||
// GetOpTypesBySource 返回指定来源的可用操作类型列表。
|
||||
func GetOpTypesBySource(source Source) []Type {
|
||||
return allOpTypes[source]
|
||||
}
|
||||
|
||||
// IsValidOpType 判断指定操作类型在给定来源下是否合法。
|
||||
func IsValidOpType(source Source, opType Type) bool {
|
||||
for _, t := range GetOpTypesBySource(source) {
|
||||
if t == opType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 操作记录结构 =====
|
||||
//
|
||||
|
||||
// Operation 表示一次完整的操作记录。
|
||||
// 用于记录系统中的操作行为,包含操作元数据、数据标识、操作者信息以及请求/响应的哈希值。
|
||||
type Operation struct {
|
||||
OpID string `json:"opId" validate:"max=32"`
|
||||
Timestamp time.Time `json:"timestamp" validate:"required"`
|
||||
OpSource Source `json:"opSource" validate:"required,oneof=IRP DOIP"`
|
||||
OpType Type `json:"opType" validate:"required"`
|
||||
DoPrefix string `json:"doPrefix" validate:"required,max=512"`
|
||||
DoRepository string `json:"doRepository" validate:"required,max=512"`
|
||||
Doid string `json:"doid" validate:"required,max=512"`
|
||||
ProducerID string `json:"producerId" validate:"required,max=512"`
|
||||
OpActor string `json:"opActor" validate:"max=64"`
|
||||
RequestBodyHash *string `json:"requestBodyHash" validate:"omitempty,max=128"`
|
||||
ResponseBodyHash *string `json:"responseBodyHash" validate:"omitempty,max=128"`
|
||||
Ack func() bool `json:"-"`
|
||||
Nack func() bool `json:"-"`
|
||||
binary []byte
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 构造函数 =====
|
||||
//
|
||||
|
||||
// NewFullOperation 创建包含所有字段的完整 Operation。
|
||||
// 自动完成哈希计算和字段校验,确保创建的 Operation 是完整且有效的。
|
||||
func NewFullOperation(
|
||||
opSource Source,
|
||||
opType Type,
|
||||
doPrefix, doRepository, doid string,
|
||||
producerID string,
|
||||
opActor string,
|
||||
requestBody, responseBody interface{},
|
||||
timestamp time.Time,
|
||||
) (*Operation, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new full operation",
|
||||
"opSource", opSource,
|
||||
"opType", opType,
|
||||
"doPrefix", doPrefix,
|
||||
"doRepository", doRepository,
|
||||
"doid", doid,
|
||||
"producerID", producerID,
|
||||
"opActor", opActor,
|
||||
)
|
||||
op := &Operation{
|
||||
Timestamp: timestamp,
|
||||
OpSource: opSource,
|
||||
OpType: opType,
|
||||
DoPrefix: doPrefix,
|
||||
DoRepository: doRepository,
|
||||
Doid: doid,
|
||||
ProducerID: producerID,
|
||||
OpActor: opActor,
|
||||
}
|
||||
|
||||
log.Debug("Setting request body hash")
|
||||
if err := op.RequestBodyFlexible(requestBody); err != nil {
|
||||
log.Error("Failed to set request body hash",
|
||||
"error", err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Setting response body hash")
|
||||
if err := op.ResponseBodyFlexible(responseBody); err != nil {
|
||||
log.Error("Failed to set response body hash",
|
||||
"error", err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Checking and initializing operation")
|
||||
if err := op.CheckAndInit(); err != nil {
|
||||
log.Error("Failed to check and init operation",
|
||||
"error", err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("Full operation created successfully",
|
||||
"opID", op.OpID,
|
||||
)
|
||||
return op, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 接口实现 =====
|
||||
//
|
||||
|
||||
func (o *Operation) Key() string {
|
||||
return o.OpID
|
||||
}
|
||||
|
||||
// OperationHashData 实现 HashData 接口,用于存储 Operation 的哈希计算结果。
|
||||
type OperationHashData struct {
|
||||
key string
|
||||
hash string
|
||||
}
|
||||
|
||||
func (o OperationHashData) Key() string {
|
||||
return o.key
|
||||
}
|
||||
|
||||
func (o OperationHashData) Hash() string {
|
||||
return o.hash
|
||||
}
|
||||
|
||||
func (o OperationHashData) Type() HashType {
|
||||
return Sha256Simd
|
||||
}
|
||||
|
||||
// DoHash 计算 Operation 的整体哈希值,用于数据完整性验证。
|
||||
// 哈希基于序列化后的二进制数据计算,确保操作记录的不可篡改性。
|
||||
func (o *Operation) DoHash(_ context.Context) (HashData, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Computing hash for operation",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
hashTool := GetHashTool(Sha256Simd)
|
||||
binary, err := o.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal operation for hash",
|
||||
"error", err,
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to marshal operation: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Computing hash bytes",
|
||||
"opID", o.OpID,
|
||||
"binaryLength", len(binary),
|
||||
)
|
||||
hash, err := hashTool.HashBytes(binary)
|
||||
if err != nil {
|
||||
log.Error("Failed to compute hash",
|
||||
"error", err,
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Hash computed successfully",
|
||||
"opID", o.OpID,
|
||||
"hash", hash,
|
||||
)
|
||||
return OperationHashData{
|
||||
key: o.OpID,
|
||||
hash: hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== CBOR 序列化相关 =====
|
||||
//
|
||||
|
||||
// operationData 用于 CBOR 序列化/反序列化的中间结构。
|
||||
// 排除函数字段和缓存字段,仅包含可序列化的数据字段。
|
||||
type operationData struct {
|
||||
OpID *string `cbor:"opId"`
|
||||
Timestamp *time.Time `cbor:"timestamp"`
|
||||
OpSource *Source `cbor:"opSource"`
|
||||
OpType *Type `cbor:"opType"`
|
||||
DoPrefix *string `cbor:"doPrefix"`
|
||||
DoRepository *string `cbor:"doRepository"`
|
||||
Doid *string `cbor:"doid"`
|
||||
ProducerID *string `cbor:"producerId"`
|
||||
OpActor *string `cbor:"opActor"`
|
||||
RequestBodyHash *string `cbor:"requestBodyHash"`
|
||||
ResponseBodyHash *string `cbor:"responseBodyHash"`
|
||||
}
|
||||
|
||||
// toOperationData 将 Operation 转换为 operationData,用于序列化。
|
||||
func (o *Operation) toOperationData() *operationData {
|
||||
return &operationData{
|
||||
OpID: &o.OpID,
|
||||
Timestamp: &o.Timestamp,
|
||||
OpSource: &o.OpSource,
|
||||
OpType: &o.OpType,
|
||||
DoPrefix: &o.DoPrefix,
|
||||
DoRepository: &o.DoRepository,
|
||||
Doid: &o.Doid,
|
||||
ProducerID: &o.ProducerID,
|
||||
OpActor: &o.OpActor,
|
||||
RequestBodyHash: o.RequestBodyHash,
|
||||
ResponseBodyHash: o.ResponseBodyHash,
|
||||
}
|
||||
}
|
||||
|
||||
// fromOperationData 从 operationData 填充 Operation,用于反序列化。
|
||||
func (o *Operation) fromOperationData(opData *operationData) {
|
||||
if opData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if opData.OpID != nil {
|
||||
o.OpID = *opData.OpID
|
||||
}
|
||||
if opData.Timestamp != nil {
|
||||
o.Timestamp = *opData.Timestamp
|
||||
}
|
||||
if opData.OpSource != nil {
|
||||
o.OpSource = *opData.OpSource
|
||||
}
|
||||
if opData.OpType != nil {
|
||||
o.OpType = *opData.OpType
|
||||
}
|
||||
if opData.DoPrefix != nil {
|
||||
o.DoPrefix = *opData.DoPrefix
|
||||
}
|
||||
if opData.DoRepository != nil {
|
||||
o.DoRepository = *opData.DoRepository
|
||||
}
|
||||
if opData.Doid != nil {
|
||||
o.Doid = *opData.Doid
|
||||
}
|
||||
if opData.ProducerID != nil {
|
||||
o.ProducerID = *opData.ProducerID
|
||||
}
|
||||
if opData.OpActor != nil {
|
||||
o.OpActor = *opData.OpActor
|
||||
}
|
||||
if opData.RequestBodyHash != nil {
|
||||
hash := *opData.RequestBodyHash
|
||||
o.RequestBodyHash = &hash
|
||||
}
|
||||
if opData.ResponseBodyHash != nil {
|
||||
hash := *opData.ResponseBodyHash
|
||||
o.ResponseBodyHash = &hash
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalBinary 将 Operation 序列化为 CBOR 格式的二进制数据。
|
||||
// 实现 encoding.BinaryMarshaler 接口。
|
||||
// 使用 Canonical CBOR 编码确保序列化结果的一致性,使用缓存机制避免重复序列化。
|
||||
func (o *Operation) MarshalBinary() ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling operation to CBOR binary",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
if o.binary != nil {
|
||||
log.Debug("Using cached binary data",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return o.binary, nil
|
||||
}
|
||||
|
||||
opData := o.toOperationData()
|
||||
|
||||
log.Debug("Marshaling operation data to canonical CBOR",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
binary, err := helpers.MarshalCanonical(opData)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal operation to CBOR",
|
||||
"error", err,
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to marshal operation to CBOR: %w", err)
|
||||
}
|
||||
|
||||
o.binary = binary
|
||||
|
||||
log.Debug("Operation marshaled successfully",
|
||||
"opID", o.OpID,
|
||||
"binaryLength", len(binary),
|
||||
)
|
||||
return binary, nil
|
||||
}
|
||||
|
||||
// GetProducerID 返回 ProducerID,实现 Trustlog 接口。
|
||||
func (o *Operation) GetProducerID() string {
|
||||
return o.ProducerID
|
||||
}
|
||||
|
||||
// UnmarshalBinary 从 CBOR 格式的二进制数据反序列化为 Operation。
|
||||
// 实现 encoding.BinaryUnmarshaler 接口。
|
||||
func (o *Operation) UnmarshalBinary(data []byte) error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Unmarshaling operation from CBOR binary",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
if len(data) == 0 {
|
||||
log.Error("Data is empty")
|
||||
return errors.New("data is empty")
|
||||
}
|
||||
|
||||
opData := &operationData{}
|
||||
|
||||
log.Debug("Unmarshaling operation data from CBOR")
|
||||
if err := helpers.Unmarshal(data, opData); err != nil {
|
||||
log.Error("Failed to unmarshal operation from CBOR",
|
||||
"error", err,
|
||||
)
|
||||
return fmt.Errorf("failed to unmarshal operation from CBOR: %w", err)
|
||||
}
|
||||
|
||||
o.fromOperationData(opData)
|
||||
|
||||
o.binary = data
|
||||
|
||||
log.Debug("Operation unmarshaled successfully",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 哈希设置方法 =====
|
||||
//
|
||||
|
||||
// setBodyHashFlexible 根据输入数据类型计算哈希,支持 string 和 []byte。
|
||||
// 使用固定的 Sha256Simd 算法。
|
||||
// 如果输入为 nil 或空,则目标指针设置为 nil,表示该字段未设置。
|
||||
func (o *Operation) setBodyHashFlexible(data interface{}, target **string) error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Setting body hash flexible",
|
||||
"opID", o.OpID,
|
||||
"dataType", fmt.Sprintf("%T", data),
|
||||
)
|
||||
if data == nil {
|
||||
log.Debug("Data is nil, setting target to nil")
|
||||
*target = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
hashTool := GetHashTool(Sha256Simd)
|
||||
var raw []byte
|
||||
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
log.Debug("String data is empty, setting target to nil")
|
||||
*target = nil
|
||||
return nil
|
||||
}
|
||||
raw = []byte(v)
|
||||
log.Debug("Converting string to bytes",
|
||||
"stringLength", len(v),
|
||||
)
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
log.Debug("Byte data is empty, setting target to nil")
|
||||
*target = nil
|
||||
return nil
|
||||
}
|
||||
raw = v
|
||||
log.Debug("Using byte data directly",
|
||||
"byteLength", len(v),
|
||||
)
|
||||
default:
|
||||
log.Error("Unsupported data type",
|
||||
"dataType", fmt.Sprintf("%T", v),
|
||||
)
|
||||
return fmt.Errorf("unsupported data type %T", v)
|
||||
}
|
||||
|
||||
log.Debug("Computing hash for body data",
|
||||
"dataLength", len(raw),
|
||||
)
|
||||
hash, err := hashTool.HashBytes(raw)
|
||||
if err != nil {
|
||||
log.Error("Failed to compute hash",
|
||||
"error", err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
*target = &hash
|
||||
log.Debug("Body hash set successfully",
|
||||
"hash", hash,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestBodyFlexible 设置请求体哈希值。
|
||||
// 支持 string 和 []byte 类型,nil 或空值会将 RequestBodyHash 设置为 nil。
|
||||
func (o *Operation) RequestBodyFlexible(data interface{}) error {
|
||||
return o.setBodyHashFlexible(data, &o.RequestBodyHash)
|
||||
}
|
||||
|
||||
// ResponseBodyFlexible 设置响应体哈希值。
|
||||
// 支持 string 和 []byte 类型,nil 或空值会将 ResponseBodyHash 设置为 nil。
|
||||
func (o *Operation) ResponseBodyFlexible(data interface{}) error {
|
||||
return o.setBodyHashFlexible(data, &o.ResponseBodyHash)
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 链式调用支持 =====
|
||||
//
|
||||
|
||||
// WithRequestBody 设置请求体哈希并返回自身,支持链式调用。
|
||||
func (o *Operation) WithRequestBody(data []byte) *Operation {
|
||||
_ = o.RequestBodyFlexible(data)
|
||||
return o
|
||||
}
|
||||
|
||||
// WithResponseBody 设置响应体哈希并返回自身,支持链式调用。
|
||||
func (o *Operation) WithResponseBody(data []byte) *Operation {
|
||||
_ = o.ResponseBodyFlexible(data)
|
||||
return o
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 初始化与验证 =====
|
||||
//
|
||||
|
||||
// CheckAndInit 校验并初始化 Operation。
|
||||
// 自动填充缺失字段(OpID、OpActor),执行业务逻辑验证(doid 格式),
|
||||
// 字段非空验证由 validate 标签处理。
|
||||
func (o *Operation) CheckAndInit() error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Checking and initializing operation",
|
||||
"opSource", o.OpSource,
|
||||
"opType", o.OpType,
|
||||
"doid", o.Doid,
|
||||
)
|
||||
if o.OpID == "" {
|
||||
o.OpID = helpers.NewUUIDv7()
|
||||
log.Debug("Generated new OpID",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
}
|
||||
|
||||
if o.OpActor == "" {
|
||||
o.OpActor = "SYSTEM"
|
||||
log.Debug("Set default OpActor to SYSTEM")
|
||||
}
|
||||
|
||||
expectedPrefix := fmt.Sprintf("%s/%s", o.DoPrefix, o.DoRepository)
|
||||
if !strings.HasPrefix(o.Doid, expectedPrefix) {
|
||||
log.Error("Doid format validation failed",
|
||||
"doid", o.Doid,
|
||||
"expectedPrefix", expectedPrefix,
|
||||
)
|
||||
return fmt.Errorf("doid must start with '%s'", expectedPrefix)
|
||||
}
|
||||
|
||||
log.Debug("Validating operation struct")
|
||||
if err := helpers.GetValidator().Struct(o); err != nil {
|
||||
log.Error("Operation validation failed",
|
||||
"error", err,
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Operation checked and initialized successfully",
|
||||
"opID", o.OpID,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
593
api/model/operation_test.go
Normal file
593
api/model/operation_test.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestOperation_Key(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "test-op-id",
|
||||
}
|
||||
assert.Equal(t, "test-op-id", op.Key())
|
||||
}
|
||||
|
||||
func TestOperation_CheckAndInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
op *model.Operation
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid operation",
|
||||
op: &model.Operation{
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "auto generate OpID",
|
||||
op: &model.Operation{
|
||||
OpID: "", // Will be auto-generated
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "auto set OpActor",
|
||||
op: &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "", // Will be set to "SYSTEM"
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid doid format",
|
||||
op: &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "invalid/123", // Doesn't start with "test/repo"
|
||||
ProducerID: "producer-1",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tt.op.CheckAndInit()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tt.name == "auto generate OpID" {
|
||||
assert.NotEmpty(t, tt.op.OpID)
|
||||
}
|
||||
if tt.name == "auto set OpActor" {
|
||||
assert.Equal(t, "SYSTEM", tt.op.OpActor)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperation_RequestBodyFlexible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
input: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
input: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty bytes",
|
||||
input: []byte{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
input: 123,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
op := &model.Operation{}
|
||||
err := op.RequestBodyFlexible(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperation_ResponseBodyFlexible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
input: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
input: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
op := &model.Operation{}
|
||||
err := op.ResponseBodyFlexible(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperation_WithRequestBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{}
|
||||
result := op.WithRequestBody([]byte("test"))
|
||||
assert.Equal(t, op, result)
|
||||
}
|
||||
|
||||
func TestOperation_WithResponseBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{}
|
||||
result := op.WithResponseBody([]byte("test"))
|
||||
assert.Equal(t, op, result)
|
||||
}
|
||||
|
||||
func TestOperation_MarshalUnmarshalBinary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := original.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// Unmarshal
|
||||
result := &model.Operation{}
|
||||
err = result.UnmarshalBinary(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
assert.Equal(t, original.OpID, result.OpID)
|
||||
assert.Equal(t, original.OpSource, result.OpSource)
|
||||
assert.Equal(t, original.OpType, result.OpType)
|
||||
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||
assert.Equal(t, original.DoRepository, result.DoRepository)
|
||||
assert.Equal(t, original.Doid, result.Doid)
|
||||
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||
assert.Equal(t, original.OpActor, result.OpActor)
|
||||
// 验证纳秒精度被保留
|
||||
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||
"时间戳的纳秒精度应该被保留")
|
||||
}
|
||||
|
||||
func TestOperation_MarshalBinary_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
}
|
||||
// MarshalBinary should succeed even without CheckAndInit
|
||||
// It just serializes the data
|
||||
data, err := op.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestOperation_UnmarshalBinary_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{}
|
||||
err := op.UnmarshalBinary([]byte{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOperation_GetProducerID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
ProducerID: "producer-123",
|
||||
}
|
||||
assert.Equal(t, "producer-123", op.GetProducerID())
|
||||
}
|
||||
|
||||
func TestOperation_DoHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
hashData, err := op.DoHash(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, hashData)
|
||||
assert.Equal(t, op.OpID, hashData.Key())
|
||||
assert.NotEmpty(t, hashData.Hash())
|
||||
}
|
||||
|
||||
func TestOperationHashData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// OperationHashData is created through DoHash, test it indirectly
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
hashData, err := op.DoHash(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, hashData)
|
||||
assert.Equal(t, "op-123", hashData.Key())
|
||||
assert.NotEmpty(t, hashData.Hash())
|
||||
assert.Equal(t, model.Sha256Simd, hashData.Type())
|
||||
}
|
||||
|
||||
func TestOperation_UnmarshalBinary_InvalidData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{}
|
||||
err := op.UnmarshalBinary([]byte("invalid-cbor-data"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal operation from CBOR")
|
||||
}
|
||||
|
||||
func TestOperation_MarshalTrustlog_EmptyProducerID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an operation with empty ProducerID
|
||||
// MarshalBinary will fail validation, but MarshalTrustlog checks ProducerID first
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "", // Empty ProducerID
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
config := model.NewEnvelopeConfig(model.NewNopSigner())
|
||||
_, err := model.MarshalTrustlog(op, config)
|
||||
// MarshalTrustlog checks ProducerID before calling MarshalBinary
|
||||
require.Error(t, err)
|
||||
// Error could be from ProducerID check or MarshalBinary validation
|
||||
assert.True(t,
|
||||
err.Error() == "producerID cannot be empty" ||
|
||||
strings.Contains(err.Error(), "ProducerID") ||
|
||||
strings.Contains(err.Error(), "producerID"))
|
||||
}
|
||||
|
||||
func TestOperation_MarshalTrustlog_NilSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
op := &model.Operation{
|
||||
OpID: "op-123",
|
||||
Timestamp: time.Now(),
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := op.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := model.EnvelopeConfig{Signer: nil}
|
||||
_, err = model.MarshalTrustlog(op, config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "signer is required")
|
||||
}
|
||||
|
||||
func TestGetOpTypesBySource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
source model.Source
|
||||
wantTypes []model.Type
|
||||
}{
|
||||
{
|
||||
name: "IRP操作类型",
|
||||
source: model.OpSourceIRP,
|
||||
wantTypes: []model.Type{
|
||||
model.OpTypeOCCreateHandle,
|
||||
model.OpTypeOCDeleteHandle,
|
||||
model.OpTypeOCAddValue,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DOIP操作类型",
|
||||
source: model.OpSourceDOIP,
|
||||
wantTypes: []model.Type{
|
||||
model.OpTypeHello,
|
||||
model.OpTypeCreate,
|
||||
model.OpTypeDelete,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
opTypes := model.GetOpTypesBySource(tt.source)
|
||||
assert.NotNil(t, opTypes)
|
||||
// Verify expected types are included
|
||||
for _, expectedType := range tt.wantTypes {
|
||||
assert.Contains(t, opTypes, expectedType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidOpType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
source model.Source
|
||||
opType model.Type
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "IRP有效操作类型",
|
||||
source: model.OpSourceIRP,
|
||||
opType: model.OpTypeOCCreateHandle,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IRP无效操作类型",
|
||||
source: model.OpSourceIRP,
|
||||
opType: model.OpTypeHello,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "DOIP有效操作类型",
|
||||
source: model.OpSourceDOIP,
|
||||
opType: model.OpTypeHello,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "DOIP无效操作类型",
|
||||
source: model.OpSourceDOIP,
|
||||
opType: model.OpTypeOCCreateHandle,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "未知来源和类型",
|
||||
source: model.Source("unknown"),
|
||||
opType: model.Type("unknown"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := model.IsValidOpType(tt.source, tt.opType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFullOperation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opSource model.Source
|
||||
opType model.Type
|
||||
doPrefix string
|
||||
doRepository string
|
||||
doid string
|
||||
producerID string
|
||||
opActor string
|
||||
requestBody interface{}
|
||||
responseBody interface{}
|
||||
timestamp time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "成功创建完整操作",
|
||||
opSource: model.OpSourceIRP,
|
||||
opType: model.OpTypeOCCreateHandle,
|
||||
doPrefix: "test",
|
||||
doRepository: "repo",
|
||||
doid: "test/repo/123",
|
||||
producerID: "producer-1",
|
||||
opActor: "actor-1",
|
||||
requestBody: []byte(`{"key": "value"}`),
|
||||
responseBody: []byte(`{"status": "ok"}`),
|
||||
timestamp: time.Now(),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空请求体和响应体",
|
||||
opSource: model.OpSourceIRP,
|
||||
opType: model.OpTypeOCCreateHandle,
|
||||
doPrefix: "test",
|
||||
doRepository: "repo",
|
||||
doid: "test/repo/123",
|
||||
producerID: "producer-1",
|
||||
opActor: "actor-1",
|
||||
requestBody: nil,
|
||||
responseBody: nil,
|
||||
timestamp: time.Now(),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "字符串类型的请求体",
|
||||
opSource: model.OpSourceIRP,
|
||||
opType: model.OpTypeOCCreateHandle,
|
||||
doPrefix: "test",
|
||||
doRepository: "repo",
|
||||
doid: "test/repo/123",
|
||||
producerID: "producer-1",
|
||||
opActor: "actor-1",
|
||||
requestBody: "string body",
|
||||
responseBody: "string response",
|
||||
timestamp: time.Now(),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
op, err := model.NewFullOperation(
|
||||
tt.opSource,
|
||||
tt.opType,
|
||||
tt.doPrefix,
|
||||
tt.doRepository,
|
||||
tt.doid,
|
||||
tt.producerID,
|
||||
tt.opActor,
|
||||
tt.requestBody,
|
||||
tt.responseBody,
|
||||
tt.timestamp,
|
||||
)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, op)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, op)
|
||||
assert.Equal(t, tt.opSource, op.OpSource)
|
||||
assert.Equal(t, tt.opType, op.OpType)
|
||||
assert.Equal(t, tt.doPrefix, op.DoPrefix)
|
||||
assert.Equal(t, tt.doRepository, op.DoRepository)
|
||||
assert.Equal(t, tt.doid, op.Doid)
|
||||
assert.Equal(t, tt.producerID, op.ProducerID)
|
||||
assert.Equal(t, tt.opActor, op.OpActor)
|
||||
assert.NotEmpty(t, op.OpID) // Should be auto-generated
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
56
api/model/operation_timestamp_test.go
Normal file
56
api/model/operation_timestamp_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestOperation_TimestampNanosecondPrecision 验证 Operation 的时间戳在 CBOR 序列化/反序列化后能保留纳秒精度
|
||||
func TestOperation_TimestampNanosecondPrecision(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 创建一个包含纳秒精度的时间戳
|
||||
timestamp := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||
|
||||
original := &model.Operation{
|
||||
OpID: "op-nanosecond-test",
|
||||
Timestamp: timestamp,
|
||||
OpSource: model.OpSourceIRP,
|
||||
OpType: model.OpTypeOCCreateHandle,
|
||||
DoPrefix: "test",
|
||||
DoRepository: "repo",
|
||||
Doid: "test/repo/123",
|
||||
ProducerID: "producer-1",
|
||||
OpActor: "actor-1",
|
||||
}
|
||||
|
||||
err := original.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||
|
||||
// 序列化
|
||||
data, err := original.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// 反序列化
|
||||
result := &model.Operation{}
|
||||
err = result.UnmarshalBinary(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Decoded timestamp: %v", result.Timestamp)
|
||||
t.Logf("Decoded nanoseconds: %d", result.Timestamp.Nanosecond())
|
||||
|
||||
// 验证纳秒精度被完整保留
|
||||
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||
"时间戳的纳秒精度应该被完整保留")
|
||||
assert.Equal(t, original.Timestamp.Nanosecond(), result.Timestamp.Nanosecond(),
|
||||
"纳秒部分应该相等")
|
||||
}
|
||||
146
api/model/proof.go
Normal file
146
api/model/proof.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||
)
|
||||
|
||||
// MerkleTreeProofItem 表示Merkle树证明项.
|
||||
type MerkleTreeProofItem struct {
|
||||
Floor uint32 // 层级
|
||||
Hash string // 哈希值
|
||||
Left bool // 是否为左节点
|
||||
}
|
||||
|
||||
// Proof 表示取证证明.
|
||||
type Proof struct {
|
||||
ColItems []*MerkleTreeProofItem // 集合项证明
|
||||
RawItems []*MerkleTreeProofItem // 原始项证明
|
||||
ColRootItem []*MerkleTreeProofItem // 集合根项证明
|
||||
RawRootItem []*MerkleTreeProofItem // 原始根项证明
|
||||
Sign string // 签名
|
||||
Version string // 版本号
|
||||
}
|
||||
|
||||
// ProofFromProtobuf 将protobuf的Proof转换为model.Proof.
|
||||
func ProofFromProtobuf(pbProof *pb.Proof) *Proof {
|
||||
if pbProof == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
proof := &Proof{
|
||||
Sign: pbProof.GetSign(),
|
||||
Version: pbProof.GetVersion(),
|
||||
}
|
||||
|
||||
// 转换 ColItems
|
||||
if pbColItems := pbProof.GetColItems(); len(pbColItems) > 0 {
|
||||
proof.ColItems = make([]*MerkleTreeProofItem, 0, len(pbColItems))
|
||||
for _, item := range pbColItems {
|
||||
proof.ColItems = append(proof.ColItems, &MerkleTreeProofItem{
|
||||
Floor: item.GetFloor(),
|
||||
Hash: item.GetHash(),
|
||||
Left: item.GetLeft(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 RawItems
|
||||
if pbRawItems := pbProof.GetRawItems(); len(pbRawItems) > 0 {
|
||||
proof.RawItems = make([]*MerkleTreeProofItem, 0, len(pbRawItems))
|
||||
for _, item := range pbRawItems {
|
||||
proof.RawItems = append(proof.RawItems, &MerkleTreeProofItem{
|
||||
Floor: item.GetFloor(),
|
||||
Hash: item.GetHash(),
|
||||
Left: item.GetLeft(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 ColRootItem
|
||||
if pbColRootItem := pbProof.GetColRootItem(); len(pbColRootItem) > 0 {
|
||||
proof.ColRootItem = make([]*MerkleTreeProofItem, 0, len(pbColRootItem))
|
||||
for _, item := range pbColRootItem {
|
||||
proof.ColRootItem = append(proof.ColRootItem, &MerkleTreeProofItem{
|
||||
Floor: item.GetFloor(),
|
||||
Hash: item.GetHash(),
|
||||
Left: item.GetLeft(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 RawRootItem
|
||||
if pbRawRootItem := pbProof.GetRawRootItem(); len(pbRawRootItem) > 0 {
|
||||
proof.RawRootItem = make([]*MerkleTreeProofItem, 0, len(pbRawRootItem))
|
||||
for _, item := range pbRawRootItem {
|
||||
proof.RawRootItem = append(proof.RawRootItem, &MerkleTreeProofItem{
|
||||
Floor: item.GetFloor(),
|
||||
Hash: item.GetHash(),
|
||||
Left: item.GetLeft(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return proof
|
||||
}
|
||||
|
||||
// ProofToProtobuf 将model.Proof转换为protobuf的Proof.
|
||||
func ProofToProtobuf(proof *Proof) *pb.Proof {
|
||||
if proof == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
Sign: proof.Sign,
|
||||
Version: proof.Version,
|
||||
}
|
||||
|
||||
// 转换 ColItems
|
||||
if len(proof.ColItems) > 0 {
|
||||
pbProof.ColItems = make([]*pb.MerkleTreeProofItem, 0, len(proof.ColItems))
|
||||
for _, item := range proof.ColItems {
|
||||
pbProof.ColItems = append(pbProof.ColItems, &pb.MerkleTreeProofItem{
|
||||
Floor: item.Floor,
|
||||
Hash: item.Hash,
|
||||
Left: item.Left,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 RawItems
|
||||
if len(proof.RawItems) > 0 {
|
||||
pbProof.RawItems = make([]*pb.MerkleTreeProofItem, 0, len(proof.RawItems))
|
||||
for _, item := range proof.RawItems {
|
||||
pbProof.RawItems = append(pbProof.RawItems, &pb.MerkleTreeProofItem{
|
||||
Floor: item.Floor,
|
||||
Hash: item.Hash,
|
||||
Left: item.Left,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 ColRootItem
|
||||
if len(proof.ColRootItem) > 0 {
|
||||
pbProof.ColRootItem = make([]*pb.MerkleTreeProofItem, 0, len(proof.ColRootItem))
|
||||
for _, item := range proof.ColRootItem {
|
||||
pbProof.ColRootItem = append(pbProof.ColRootItem, &pb.MerkleTreeProofItem{
|
||||
Floor: item.Floor,
|
||||
Hash: item.Hash,
|
||||
Left: item.Left,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 RawRootItem
|
||||
if len(proof.RawRootItem) > 0 {
|
||||
pbProof.RawRootItem = make([]*pb.MerkleTreeProofItem, 0, len(proof.RawRootItem))
|
||||
for _, item := range proof.RawRootItem {
|
||||
pbProof.RawRootItem = append(pbProof.RawRootItem, &pb.MerkleTreeProofItem{
|
||||
Floor: item.Floor,
|
||||
Hash: item.Hash,
|
||||
Left: item.Left,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return pbProof
|
||||
}
|
||||
349
api/model/proof_test.go
Normal file
349
api/model/proof_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/grpc/pb"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestProofFromProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := model.ProofFromProtobuf(nil)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Empty(t, result.Sign)
|
||||
assert.Empty(t, result.Version)
|
||||
assert.Nil(t, result.ColItems)
|
||||
assert.Nil(t, result.RawItems)
|
||||
assert.Nil(t, result.ColRootItem)
|
||||
assert.Nil(t, result.RawRootItem)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithSign(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
Sign: "test-signature",
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "test-signature", result.Sign)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
Version: "v1.0.0",
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "v1.0.0", result.Version)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithColItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
ColItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "hash1", Left: true},
|
||||
{Floor: 2, Hash: "hash2", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.ColItems, 2)
|
||||
assert.Equal(t, uint32(1), result.ColItems[0].Floor)
|
||||
assert.Equal(t, "hash1", result.ColItems[0].Hash)
|
||||
assert.True(t, result.ColItems[0].Left)
|
||||
assert.Equal(t, uint32(2), result.ColItems[1].Floor)
|
||||
assert.Equal(t, "hash2", result.ColItems[1].Hash)
|
||||
assert.False(t, result.ColItems[1].Left)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithRawItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
RawItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 3, Hash: "hash3", Left: true},
|
||||
},
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.RawItems, 1)
|
||||
assert.Equal(t, uint32(3), result.RawItems[0].Floor)
|
||||
assert.Equal(t, "hash3", result.RawItems[0].Hash)
|
||||
assert.True(t, result.RawItems[0].Left)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithColRootItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 4, Hash: "hash4", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.ColRootItem, 1)
|
||||
assert.Equal(t, uint32(4), result.ColRootItem[0].Floor)
|
||||
assert.Equal(t, "hash4", result.ColRootItem[0].Hash)
|
||||
assert.False(t, result.ColRootItem[0].Left)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_WithRawRootItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 5, Hash: "hash5", Left: true},
|
||||
},
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.RawRootItem, 1)
|
||||
assert.Equal(t, uint32(5), result.RawRootItem[0].Floor)
|
||||
assert.Equal(t, "hash5", result.RawRootItem[0].Hash)
|
||||
assert.True(t, result.RawRootItem[0].Left)
|
||||
}
|
||||
|
||||
func TestProofFromProtobuf_Full(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pbProof := &pb.Proof{
|
||||
Sign: "full-signature",
|
||||
Version: "v1.0.0",
|
||||
ColItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "col1", Left: true},
|
||||
},
|
||||
RawItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 2, Hash: "raw1", Left: false},
|
||||
},
|
||||
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 3, Hash: "colroot1", Left: true},
|
||||
},
|
||||
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 4, Hash: "rawroot1", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofFromProtobuf(pbProof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "full-signature", result.Sign)
|
||||
assert.Equal(t, "v1.0.0", result.Version)
|
||||
assert.Len(t, result.ColItems, 1)
|
||||
assert.Len(t, result.RawItems, 1)
|
||||
assert.Len(t, result.ColRootItem, 1)
|
||||
assert.Len(t, result.RawRootItem, 1)
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := model.ProofToProtobuf(nil)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Empty(t, result.GetSign())
|
||||
assert.Empty(t, result.GetVersion())
|
||||
assert.Nil(t, result.GetColItems())
|
||||
assert.Nil(t, result.GetRawItems())
|
||||
assert.Nil(t, result.GetColRootItem())
|
||||
assert.Nil(t, result.GetRawRootItem())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithSign(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
Sign: "test-signature",
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "test-signature", result.GetSign())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
Version: "v1.0.0",
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "v1.0.0", result.GetVersion())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithColItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
ColItems: []*model.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "hash1", Left: true},
|
||||
{Floor: 2, Hash: "hash2", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.GetColItems(), 2)
|
||||
assert.Equal(t, uint32(1), result.GetColItems()[0].GetFloor())
|
||||
assert.Equal(t, "hash1", result.GetColItems()[0].GetHash())
|
||||
assert.True(t, result.GetColItems()[0].GetLeft())
|
||||
assert.Equal(t, uint32(2), result.GetColItems()[1].GetFloor())
|
||||
assert.Equal(t, "hash2", result.GetColItems()[1].GetHash())
|
||||
assert.False(t, result.GetColItems()[1].GetLeft())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithRawItems(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
RawItems: []*model.MerkleTreeProofItem{
|
||||
{Floor: 3, Hash: "hash3", Left: true},
|
||||
},
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.GetRawItems(), 1)
|
||||
assert.Equal(t, uint32(3), result.GetRawItems()[0].GetFloor())
|
||||
assert.Equal(t, "hash3", result.GetRawItems()[0].GetHash())
|
||||
assert.True(t, result.GetRawItems()[0].GetLeft())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithColRootItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
ColRootItem: []*model.MerkleTreeProofItem{
|
||||
{Floor: 4, Hash: "hash4", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.GetColRootItem(), 1)
|
||||
assert.Equal(t, uint32(4), result.GetColRootItem()[0].GetFloor())
|
||||
assert.Equal(t, "hash4", result.GetColRootItem()[0].GetHash())
|
||||
assert.False(t, result.GetColRootItem()[0].GetLeft())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_WithRawRootItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
RawRootItem: []*model.MerkleTreeProofItem{
|
||||
{Floor: 5, Hash: "hash5", Left: true},
|
||||
},
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.GetRawRootItem(), 1)
|
||||
assert.Equal(t, uint32(5), result.GetRawRootItem()[0].GetFloor())
|
||||
assert.Equal(t, "hash5", result.GetRawRootItem()[0].GetHash())
|
||||
assert.True(t, result.GetRawRootItem()[0].GetLeft())
|
||||
}
|
||||
|
||||
func TestProofToProtobuf_Full(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proof := &model.Proof{
|
||||
Sign: "full-signature",
|
||||
Version: "v1.0.0",
|
||||
ColItems: []*model.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "col1", Left: true},
|
||||
},
|
||||
RawItems: []*model.MerkleTreeProofItem{
|
||||
{Floor: 2, Hash: "raw1", Left: false},
|
||||
},
|
||||
ColRootItem: []*model.MerkleTreeProofItem{
|
||||
{Floor: 3, Hash: "colroot1", Left: true},
|
||||
},
|
||||
RawRootItem: []*model.MerkleTreeProofItem{
|
||||
{Floor: 4, Hash: "rawroot1", Left: false},
|
||||
},
|
||||
}
|
||||
result := model.ProofToProtobuf(proof)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "full-signature", result.GetSign())
|
||||
assert.Equal(t, "v1.0.0", result.GetVersion())
|
||||
assert.Len(t, result.GetColItems(), 1)
|
||||
assert.Len(t, result.GetRawItems(), 1)
|
||||
assert.Len(t, result.GetColRootItem(), 1)
|
||||
assert.Len(t, result.GetRawRootItem(), 1)
|
||||
}
|
||||
|
||||
func TestProofRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := &pb.Proof{
|
||||
Sign: "round-trip-signature",
|
||||
Version: "v1.0.0",
|
||||
ColItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 1, Hash: "col1", Left: true},
|
||||
{Floor: 2, Hash: "col2", Left: false},
|
||||
},
|
||||
RawItems: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 3, Hash: "raw1", Left: true},
|
||||
},
|
||||
ColRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 4, Hash: "colroot1", Left: false},
|
||||
},
|
||||
RawRootItem: []*pb.MerkleTreeProofItem{
|
||||
{Floor: 5, Hash: "rawroot1", Left: true},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert to model
|
||||
modelProof := model.ProofFromProtobuf(original)
|
||||
require.NotNil(t, modelProof)
|
||||
|
||||
// Convert back to protobuf
|
||||
pbProof := model.ProofToProtobuf(modelProof)
|
||||
require.NotNil(t, pbProof)
|
||||
|
||||
// Verify round trip
|
||||
assert.Equal(t, original.GetSign(), pbProof.GetSign())
|
||||
assert.Equal(t, original.GetVersion(), pbProof.GetVersion())
|
||||
assert.Len(t, pbProof.GetColItems(), 2)
|
||||
assert.Len(t, pbProof.GetRawItems(), 1)
|
||||
assert.Len(t, pbProof.GetColRootItem(), 1)
|
||||
assert.Len(t, pbProof.GetRawRootItem(), 1)
|
||||
|
||||
assert.Equal(t, original.GetColItems()[0].GetFloor(), pbProof.GetColItems()[0].GetFloor())
|
||||
assert.Equal(t, original.GetColItems()[0].GetHash(), pbProof.GetColItems()[0].GetHash())
|
||||
assert.Equal(t, original.GetColItems()[0].GetLeft(), pbProof.GetColItems()[0].GetLeft())
|
||||
}
|
||||
348
api/model/record.go
Normal file
348
api/model/record.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/internal/helpers"
|
||||
)
|
||||
|
||||
// Record 表示一条记录。
|
||||
// 用于记录系统中的操作行为,包含记录标识、节点前缀、操作者信息等。
|
||||
type Record struct {
|
||||
ID string `json:"id" validate:"required,max=128"`
|
||||
DoPrefix string `json:"doPrefix" validate:"max=512"`
|
||||
ProducerID string `json:"producerId" validate:"required,max=512"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Operator string `json:"operator" validate:"max=64"`
|
||||
Extra []byte `json:"extra" validate:"max=512"`
|
||||
RCType string `json:"type" validate:"max=64"`
|
||||
binary []byte
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 构造函数 =====
|
||||
//
|
||||
|
||||
// NewFullRecord 创建包含所有字段的完整 Record。
|
||||
// 自动完成字段校验,确保创建的 Record 是完整且有效的。
|
||||
func NewFullRecord(
|
||||
doPrefix string,
|
||||
producerID string,
|
||||
timestamp time.Time,
|
||||
operator string,
|
||||
extra []byte,
|
||||
rcType string,
|
||||
) (*Record, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new full record",
|
||||
"doPrefix", doPrefix,
|
||||
"producerID", producerID,
|
||||
"operator", operator,
|
||||
"rcType", rcType,
|
||||
"extraLength", len(extra),
|
||||
)
|
||||
record := &Record{
|
||||
DoPrefix: doPrefix,
|
||||
ProducerID: producerID,
|
||||
Timestamp: timestamp,
|
||||
Operator: operator,
|
||||
Extra: extra,
|
||||
RCType: rcType,
|
||||
}
|
||||
|
||||
log.Debug("Checking and initializing record")
|
||||
if err := record.CheckAndInit(); err != nil {
|
||||
log.Error("Failed to check and init record",
|
||||
"error", err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("Full record created successfully",
|
||||
"recordID", record.ID,
|
||||
)
|
||||
return record, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 接口实现 =====
|
||||
//
|
||||
|
||||
func (r *Record) Key() string {
|
||||
return r.ID
|
||||
}
|
||||
|
||||
// RecordHashData 实现 HashData 接口,用于存储 Record 的哈希计算结果。
|
||||
type RecordHashData struct {
|
||||
key string
|
||||
hash string
|
||||
}
|
||||
|
||||
func (r RecordHashData) Key() string {
|
||||
return r.key
|
||||
}
|
||||
|
||||
func (r RecordHashData) Hash() string {
|
||||
return r.hash
|
||||
}
|
||||
|
||||
func (r RecordHashData) Type() HashType {
|
||||
return Sha256Simd
|
||||
}
|
||||
|
||||
// DoHash 计算 Record 的整体哈希值,用于数据完整性验证。
|
||||
// 哈希基于序列化后的二进制数据计算,确保记录数据的不可篡改性。
|
||||
func (r *Record) DoHash(_ context.Context) (HashData, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Computing hash for record",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
hashTool := GetHashTool(Sha256Simd)
|
||||
binary, err := r.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal record for hash",
|
||||
"error", err,
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to marshal record: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Computing hash bytes",
|
||||
"recordID", r.ID,
|
||||
"binaryLength", len(binary),
|
||||
)
|
||||
hash, err := hashTool.HashBytes(binary)
|
||||
if err != nil {
|
||||
log.Error("Failed to compute hash",
|
||||
"error", err,
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Hash computed successfully",
|
||||
"recordID", r.ID,
|
||||
"hash", hash,
|
||||
)
|
||||
return RecordHashData{
|
||||
key: r.ID,
|
||||
hash: hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== CBOR 序列化相关 =====
|
||||
//
|
||||
|
||||
// recordData 用于 CBOR 序列化/反序列化的中间结构。
|
||||
// 排除缓存字段,仅包含可序列化的数据字段。
|
||||
type recordData struct {
|
||||
ID *string `cbor:"id"`
|
||||
DoPrefix *string `cbor:"doPrefix"`
|
||||
ProducerID *string `cbor:"producerId"`
|
||||
Timestamp *time.Time `cbor:"timestamp"`
|
||||
Operator *string `cbor:"operator"`
|
||||
Extra []byte `cbor:"extra"`
|
||||
RCType *string `cbor:"type"`
|
||||
}
|
||||
|
||||
// toRecordData 将 Record 转换为 recordData,用于序列化。
|
||||
func (r *Record) toRecordData() *recordData {
|
||||
return &recordData{
|
||||
ID: &r.ID,
|
||||
DoPrefix: &r.DoPrefix,
|
||||
ProducerID: &r.ProducerID,
|
||||
Timestamp: &r.Timestamp,
|
||||
Operator: &r.Operator,
|
||||
Extra: r.Extra,
|
||||
RCType: &r.RCType,
|
||||
}
|
||||
}
|
||||
|
||||
// fromRecordData 从 recordData 填充 Record,用于反序列化。
|
||||
func (r *Record) fromRecordData(recData *recordData) {
|
||||
if recData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if recData.ID != nil {
|
||||
r.ID = *recData.ID
|
||||
}
|
||||
if recData.DoPrefix != nil {
|
||||
r.DoPrefix = *recData.DoPrefix
|
||||
}
|
||||
if recData.ProducerID != nil {
|
||||
r.ProducerID = *recData.ProducerID
|
||||
}
|
||||
if recData.Timestamp != nil {
|
||||
r.Timestamp = *recData.Timestamp
|
||||
}
|
||||
if recData.Operator != nil {
|
||||
r.Operator = *recData.Operator
|
||||
}
|
||||
if recData.Extra != nil {
|
||||
r.Extra = recData.Extra
|
||||
}
|
||||
if recData.RCType != nil {
|
||||
r.RCType = *recData.RCType
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalBinary 将 Record 序列化为 CBOR 格式的二进制数据。
|
||||
// 实现 encoding.BinaryMarshaler 接口。
|
||||
// 使用 Canonical CBOR 编码确保序列化结果的一致性,使用缓存机制避免重复序列化。
|
||||
func (r *Record) MarshalBinary() ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling record to CBOR binary",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
if r.binary != nil {
|
||||
log.Debug("Using cached binary data",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return r.binary, nil
|
||||
}
|
||||
|
||||
recData := r.toRecordData()
|
||||
|
||||
log.Debug("Marshaling record data to canonical CBOR",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
binary, err := helpers.MarshalCanonical(recData)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal record to CBOR",
|
||||
"error", err,
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return nil, fmt.Errorf("failed to marshal record to CBOR: %w", err)
|
||||
}
|
||||
|
||||
r.binary = binary
|
||||
|
||||
log.Debug("Record marshaled successfully",
|
||||
"recordID", r.ID,
|
||||
"binaryLength", len(binary),
|
||||
)
|
||||
return binary, nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary 从 CBOR 格式的二进制数据反序列化为 Record。
|
||||
// 实现 encoding.BinaryUnmarshaler 接口。
|
||||
func (r *Record) UnmarshalBinary(data []byte) error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Unmarshaling record from CBOR binary",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
if len(data) == 0 {
|
||||
log.Error("Data is empty")
|
||||
return errors.New("data is empty")
|
||||
}
|
||||
|
||||
recData := &recordData{}
|
||||
|
||||
log.Debug("Unmarshaling record data from CBOR")
|
||||
if err := helpers.Unmarshal(data, recData); err != nil {
|
||||
log.Error("Failed to unmarshal record from CBOR",
|
||||
"error", err,
|
||||
)
|
||||
return fmt.Errorf("failed to unmarshal record from CBOR: %w", err)
|
||||
}
|
||||
|
||||
r.fromRecordData(recData)
|
||||
|
||||
r.binary = data
|
||||
|
||||
log.Debug("Record unmarshaled successfully",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDoPrefix 实现 DoPrefixExtractor 接口,返回节点前缀。
|
||||
func (r *Record) GetDoPrefix() string {
|
||||
return r.DoPrefix
|
||||
}
|
||||
|
||||
// GetProducerID 返回 ProducerID,实现 Trustlog 接口。
|
||||
func (r *Record) GetProducerID() string {
|
||||
return r.ProducerID
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 初始化与验证 =====
|
||||
//
|
||||
|
||||
// CheckAndInit 校验并初始化 Record。
|
||||
// 自动填充缺失字段(ID),字段非空验证由 validate 标签处理。
|
||||
func (r *Record) CheckAndInit() error {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Checking and initializing record",
|
||||
"producerID", r.ProducerID,
|
||||
"doPrefix", r.DoPrefix,
|
||||
)
|
||||
if r.ID == "" {
|
||||
r.ID = helpers.NewUUIDv7()
|
||||
log.Debug("Generated new record ID",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
}
|
||||
|
||||
if r.Timestamp.IsZero() {
|
||||
r.Timestamp = time.Now()
|
||||
log.Debug("Set default timestamp",
|
||||
"timestamp", r.Timestamp,
|
||||
)
|
||||
}
|
||||
|
||||
log.Debug("Validating record struct")
|
||||
if err := helpers.GetValidator().Struct(r); err != nil {
|
||||
log.Error("Record validation failed",
|
||||
"error", err,
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Record checked and initialized successfully",
|
||||
"recordID", r.ID,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// ===== 链式调用支持 =====
|
||||
//
|
||||
|
||||
// WithDoPrefix 设置 DoPrefix 并返回自身,支持链式调用。
|
||||
func (r *Record) WithDoPrefix(doPrefix string) *Record {
|
||||
r.DoPrefix = doPrefix
|
||||
return r
|
||||
}
|
||||
|
||||
// WithTimestamp 设置 Timestamp 并返回自身,支持链式调用。
|
||||
func (r *Record) WithTimestamp(timestamp time.Time) *Record {
|
||||
r.Timestamp = timestamp
|
||||
return r
|
||||
}
|
||||
|
||||
// WithOperator 设置 Operator 并返回自身,支持链式调用。
|
||||
func (r *Record) WithOperator(operator string) *Record {
|
||||
r.Operator = operator
|
||||
return r
|
||||
}
|
||||
|
||||
// WithExtra 设置 Extra 并返回自身,支持链式调用。
|
||||
func (r *Record) WithExtra(extra []byte) *Record {
|
||||
r.Extra = extra
|
||||
return r
|
||||
}
|
||||
|
||||
// WithRCType 设置 RCType 并返回自身,支持链式调用。
|
||||
func (r *Record) WithRCType(rcType string) *Record {
|
||||
r.RCType = rcType
|
||||
return r
|
||||
}
|
||||
321
api/model/record_test.go
Normal file
321
api/model/record_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestRecord_Key(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
ID: "test-record-id",
|
||||
}
|
||||
assert.Equal(t, "test-record-id", rec.Key())
|
||||
}
|
||||
|
||||
func TestNewFullRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
rec, err := model.NewFullRecord(
|
||||
"test-prefix",
|
||||
"producer-1",
|
||||
now,
|
||||
"operator-1",
|
||||
[]byte("extra"),
|
||||
"log",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, rec)
|
||||
assert.NotEmpty(t, rec.ID)
|
||||
assert.Equal(t, "test-prefix", rec.DoPrefix)
|
||||
assert.Equal(t, "producer-1", rec.ProducerID)
|
||||
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||
assert.Equal(t, "operator-1", rec.Operator)
|
||||
assert.Equal(t, []byte("extra"), rec.Extra)
|
||||
assert.Equal(t, "log", rec.RCType)
|
||||
}
|
||||
|
||||
func TestNewFullRecord_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
// Missing required ProducerID
|
||||
rec, err := model.NewFullRecord(
|
||||
"test-prefix",
|
||||
"", // Empty ProducerID
|
||||
now,
|
||||
"operator-1",
|
||||
[]byte("extra"),
|
||||
"log",
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, rec)
|
||||
}
|
||||
|
||||
func TestRecord_CheckAndInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rec *model.Record
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid record",
|
||||
rec: &model.Record{
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "auto generate ID",
|
||||
rec: &model.Record{
|
||||
ID: "", // Will be auto-generated
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing ProducerID",
|
||||
rec: &model.Record{
|
||||
DoPrefix: "test",
|
||||
ProducerID: "", // Required field
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tt.rec.CheckAndInit()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tt.name == "auto generate ID" {
|
||||
assert.NotEmpty(t, tt.rec.ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecord_MarshalUnmarshalBinary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := original.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal
|
||||
data, err := original.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// Unmarshal
|
||||
result := &model.Record{}
|
||||
err = result.UnmarshalBinary(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
assert.Equal(t, original.ID, result.ID)
|
||||
assert.Equal(t, original.DoPrefix, result.DoPrefix)
|
||||
assert.Equal(t, original.ProducerID, result.ProducerID)
|
||||
assert.Equal(t, original.Timestamp.Unix(), result.Timestamp.Unix())
|
||||
// 验证纳秒精度被保留
|
||||
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||
"时间戳的纳秒精度应该被保留")
|
||||
assert.Equal(t, original.Operator, result.Operator)
|
||||
assert.Equal(t, original.Extra, result.Extra)
|
||||
assert.Equal(t, original.RCType, result.RCType)
|
||||
}
|
||||
|
||||
func TestRecord_MarshalBinary_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
// MarshalBinary should succeed even without CheckAndInit
|
||||
// It just serializes the data
|
||||
data, err := rec.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, data)
|
||||
}
|
||||
|
||||
func TestRecord_UnmarshalBinary_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
err := rec.UnmarshalBinary([]byte{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRecord_DoHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := rec.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
hashData, err := rec.DoHash(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, hashData)
|
||||
assert.Equal(t, rec.ID, hashData.Key())
|
||||
assert.NotEmpty(t, hashData.Hash())
|
||||
}
|
||||
|
||||
func TestRecordHashData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// RecordHashData is created through DoHash, test it indirectly
|
||||
rec := &model.Record{
|
||||
ID: "rec-123",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: time.Now(),
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := rec.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
hashData, err := rec.DoHash(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, hashData)
|
||||
assert.Equal(t, "rec-123", hashData.Key())
|
||||
assert.NotEmpty(t, hashData.Hash())
|
||||
assert.Equal(t, model.Sha256Simd, hashData.Type())
|
||||
}
|
||||
|
||||
func TestRecord_GetProducerID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
ProducerID: "producer-123",
|
||||
}
|
||||
assert.Equal(t, "producer-123", rec.GetProducerID())
|
||||
}
|
||||
|
||||
func TestRecord_GetDoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{
|
||||
DoPrefix: "test-prefix",
|
||||
}
|
||||
assert.Equal(t, "test-prefix", rec.GetDoPrefix())
|
||||
}
|
||||
|
||||
func TestRecord_WithDoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
result := rec.WithDoPrefix("test-prefix")
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, "test-prefix", rec.DoPrefix)
|
||||
}
|
||||
|
||||
func TestRecord_WithTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
now := time.Now()
|
||||
result := rec.WithTimestamp(now)
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||
}
|
||||
|
||||
func TestRecord_WithOperator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
result := rec.WithOperator("operator-1")
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, "operator-1", rec.Operator)
|
||||
}
|
||||
|
||||
func TestRecord_WithExtra(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
extra := []byte("extra-data")
|
||||
result := rec.WithExtra(extra)
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, extra, rec.Extra)
|
||||
}
|
||||
|
||||
func TestRecord_WithRCType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
result := rec.WithRCType("log")
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, "log", rec.RCType)
|
||||
}
|
||||
|
||||
func TestRecord_ChainedMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := &model.Record{}
|
||||
now := time.Now()
|
||||
result := rec.
|
||||
WithDoPrefix("prefix").
|
||||
WithTimestamp(now).
|
||||
WithOperator("operator").
|
||||
WithExtra([]byte("extra")).
|
||||
WithRCType("log")
|
||||
|
||||
assert.Equal(t, rec, result)
|
||||
assert.Equal(t, "prefix", rec.DoPrefix)
|
||||
assert.Equal(t, now.Unix(), rec.Timestamp.Unix())
|
||||
assert.Equal(t, "operator", rec.Operator)
|
||||
assert.Equal(t, []byte("extra"), rec.Extra)
|
||||
assert.Equal(t, "log", rec.RCType)
|
||||
}
|
||||
54
api/model/record_timestamp_test.go
Normal file
54
api/model/record_timestamp_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestRecord_TimestampNanosecondPrecision 验证 Record 的时间戳在 CBOR 序列化/反序列化后能保留纳秒精度
|
||||
func TestRecord_TimestampNanosecondPrecision(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 创建一个包含纳秒精度的时间戳
|
||||
timestamp := time.Date(2024, 1, 1, 12, 30, 45, 123456789, time.UTC)
|
||||
|
||||
original := &model.Record{
|
||||
ID: "rec-nanosecond-test",
|
||||
DoPrefix: "test",
|
||||
ProducerID: "producer-1",
|
||||
Timestamp: timestamp,
|
||||
Operator: "operator-1",
|
||||
Extra: []byte("extra"),
|
||||
RCType: "log",
|
||||
}
|
||||
|
||||
err := original.CheckAndInit()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Original timestamp: %v", original.Timestamp)
|
||||
t.Logf("Original nanoseconds: %d", original.Timestamp.Nanosecond())
|
||||
|
||||
// 序列化
|
||||
data, err := original.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data)
|
||||
|
||||
// 反序列化
|
||||
result := &model.Record{}
|
||||
err = result.UnmarshalBinary(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Decoded timestamp: %v", result.Timestamp)
|
||||
t.Logf("Decoded nanoseconds: %d", result.Timestamp.Nanosecond())
|
||||
|
||||
// 验证纳秒精度被完整保留
|
||||
assert.Equal(t, original.Timestamp.UnixNano(), result.Timestamp.UnixNano(),
|
||||
"时间戳的纳秒精度应该被完整保留")
|
||||
assert.Equal(t, original.Timestamp.Nanosecond(), result.Timestamp.Nanosecond(),
|
||||
"纳秒部分应该相等")
|
||||
}
|
||||
393
api/model/signature.go
Normal file
393
api/model/signature.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/crpt/go-crpt"
|
||||
_ "github.com/crpt/go-crpt/sm2" // Import SM2 to register it
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivateKeyIsNil = errors.New("private key is nil")
|
||||
ErrPublicAndKeysNotMatch = errors.New("public and private keys don't match")
|
||||
)
|
||||
|
||||
// ComputeSignature 计算SM2签名.
|
||||
// 这是 SDK 默认的签名函数,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||
//
|
||||
// 参数:
|
||||
// - data: 待签名的原始数据
|
||||
// - privateKeyDER: 私钥的DER编码字节数组
|
||||
//
|
||||
// 返回: 签名字节数组.
|
||||
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并签名。
|
||||
func ComputeSignature(data, privateKeyDER []byte) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Computing SM2 signature",
|
||||
"dataLength", len(data),
|
||||
"privateKeyDERLength", len(privateKeyDER),
|
||||
)
|
||||
|
||||
if len(privateKeyDER) == 0 {
|
||||
log.Error("Private key is empty")
|
||||
return nil, errors.New("private key cannot be empty")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
log.Error("Data to sign is empty")
|
||||
return nil, errors.New("data to sign cannot be empty")
|
||||
}
|
||||
|
||||
// 解析DER格式的私钥
|
||||
log.Debug("Parsing SM2 private key from DER format")
|
||||
privateKey, err := crpt.PrivateKeyFromBytes(crpt.SM2, privateKeyDER)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse SM2 private key",
|
||||
"error", err,
|
||||
"keyLength", len(privateKeyDER),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to parse SM2 private key (key length: %d): %w", len(privateKeyDER), err)
|
||||
}
|
||||
|
||||
if privateKey == nil {
|
||||
log.Error("Parsed private key is nil")
|
||||
return nil, ErrPrivateKeyIsNil
|
||||
}
|
||||
|
||||
// 使用SM2签名(ASN.1编码),go-crpt 库会自动使用 SM3 计算摘要
|
||||
log.Debug("Signing raw data with SM2 using ASN.1 encoding (SM3 hash)")
|
||||
signature, err := crpt.SignMessage(privateKey, data, rand.Reader, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign data with SM2",
|
||||
"error", err,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to sign data with SM2 (data length: %d): %w", len(data), err)
|
||||
}
|
||||
|
||||
log.Debug("SM2 signature computed successfully",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// VerifySignature 验证SM2签名.
|
||||
// 这是 SDK 默认的验签函数,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||
//
|
||||
// 参数:
|
||||
// - data: 原始数据
|
||||
// - publicKeyDER: 公钥的DER编码字节数组
|
||||
// - signature: 签名字节数组
|
||||
//
|
||||
// 返回: 验证是否成功和可能的错误.
|
||||
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并验证。
|
||||
func VerifySignature(data, publicKeyDER, signature []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying SM2 signature",
|
||||
"dataLength", len(data),
|
||||
"publicKeyDERLength", len(publicKeyDER),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
|
||||
if len(publicKeyDER) == 0 {
|
||||
log.Error("Public key is empty")
|
||||
return false, errors.New("public key cannot be empty")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
log.Error("Data to verify is empty")
|
||||
return false, errors.New("data to verify cannot be empty")
|
||||
}
|
||||
|
||||
if len(signature) == 0 {
|
||||
log.Error("Signature is empty")
|
||||
return false, errors.New("signature cannot be empty")
|
||||
}
|
||||
|
||||
// 解析DER格式的公钥,复用ParseSM2PublicDER以避免代码重复
|
||||
log.Debug("Parsing SM2 public key from DER format")
|
||||
publicKey, err := ParseSM2PublicDER(publicKeyDER)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse SM2 public key",
|
||||
"error", err,
|
||||
"keyLength", len(publicKeyDER),
|
||||
)
|
||||
return false, fmt.Errorf("failed to parse SM2 public key (key length: %d): %w", len(publicKeyDER), err)
|
||||
}
|
||||
|
||||
// 验证签名(ASN.1编码),go-crpt 库会自动使用 SM3 计算摘要
|
||||
log.Debug("Verifying signature with SM2 using ASN.1 encoding (SM3 hash)")
|
||||
ok, err := crpt.VerifyMessage(publicKey, data, crpt.Signature(signature), nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to verify SM2 signature",
|
||||
"error", err,
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return false, fmt.Errorf("failed to verify signature: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
log.Warn("SM2 signature verification failed",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return false, fmt.Errorf(
|
||||
"signature verification failed (data length: %d, signature length: %d)",
|
||||
len(data), len(signature),
|
||||
)
|
||||
}
|
||||
log.Debug("SM2 signature verified successfully",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GenerateSM2KeyPair 生成SM2密钥对.
|
||||
// 这是 SDK 默认推荐的密钥生成方法。
|
||||
//
|
||||
// 返回新生成的密钥对,包含公钥和私钥.
|
||||
// SM2 算法会在签名时自动使用 SM3 哈希。
|
||||
func GenerateSM2KeyPair() (*SM2KeyPair, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Generating SM2 key pair")
|
||||
pub, priv, err := crpt.GenerateKey(crpt.SM2, rand.Reader)
|
||||
if err != nil {
|
||||
log.Error("Failed to generate SM2 key pair", "error", err)
|
||||
return nil, fmt.Errorf("failed to generate SM2 key pair: %w", err)
|
||||
}
|
||||
|
||||
if priv == nil {
|
||||
log.Error("Generated private key is nil")
|
||||
return nil, errors.New("generated private key is nil")
|
||||
}
|
||||
|
||||
log.Debug("SM2 key pair generated successfully")
|
||||
return &SM2KeyPair{
|
||||
Public: pub,
|
||||
Private: priv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SM2KeyPair SM2密钥对,包含公钥和私钥.
|
||||
type SM2KeyPair struct {
|
||||
Public crpt.PublicKey `json:"publicKey"`
|
||||
Private crpt.PrivateKey `json:"privateKey"`
|
||||
}
|
||||
|
||||
// MarshalSM2PrivateDER 将私钥编码为DER格式.
|
||||
// 将SM2私钥转换为DER格式的字节数组,用于存储或传输.
|
||||
func MarshalSM2PrivateDER(priv crpt.PrivateKey) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling SM2 private key to DER format")
|
||||
if priv == nil {
|
||||
log.Error("Private key is nil")
|
||||
return nil, errors.New("private key is nil")
|
||||
}
|
||||
|
||||
der := priv.Bytes()
|
||||
log.Debug("SM2 private key marshaled to DER successfully",
|
||||
"derLength", len(der),
|
||||
)
|
||||
return der, nil
|
||||
}
|
||||
|
||||
// ParseSM2PrivateDER 从DER格式解析私钥.
|
||||
// 将DER格式的字节数组解析为SM2私钥对象.
|
||||
func ParseSM2PrivateDER(der []byte) (crpt.PrivateKey, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Parsing SM2 private key from DER format",
|
||||
"derLength", len(der),
|
||||
)
|
||||
if len(der) == 0 {
|
||||
log.Error("DER encoded private key is empty")
|
||||
return nil, errors.New("DER encoded private key cannot be empty")
|
||||
}
|
||||
|
||||
key, err := crpt.PrivateKeyFromBytes(crpt.SM2, der)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse SM2 private key from DER",
|
||||
"error", err,
|
||||
"derLength", len(der),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to parse SM2 private key from DER (length: %d): %w", len(der), err)
|
||||
}
|
||||
log.Debug("SM2 private key parsed from DER successfully")
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// MarshalSM2PublicDER 将公钥编码为DER格式.
|
||||
// 将SM2公钥转换为DER格式的字节数组,用于存储或传输.
|
||||
func MarshalSM2PublicDER(pub crpt.PublicKey) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Marshaling SM2 public key to DER format")
|
||||
if pub == nil {
|
||||
log.Error("Public key is nil")
|
||||
return nil, errors.New("public key is nil")
|
||||
}
|
||||
|
||||
der := pub.Bytes()
|
||||
log.Debug("SM2 public key marshaled to DER successfully",
|
||||
"derLength", len(der),
|
||||
)
|
||||
return der, nil
|
||||
}
|
||||
|
||||
// ParseSM2PublicDER 从DER格式解析公钥.
|
||||
// 将DER格式的字节数组解析为SM2公钥对象.
|
||||
// 返回解析后的公钥,如果数据不是有效的SM2公钥则返回错误.
|
||||
func ParseSM2PublicDER(der []byte) (crpt.PublicKey, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Parsing SM2 public key from DER format",
|
||||
"derLength", len(der),
|
||||
)
|
||||
if len(der) == 0 {
|
||||
log.Error("DER encoded public key is empty")
|
||||
return nil, errors.New("DER encoded public key cannot be empty")
|
||||
}
|
||||
|
||||
publicKey, err := crpt.PublicKeyFromBytes(crpt.SM2, der)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse SM2 public key",
|
||||
"error", err,
|
||||
"derLength", len(der),
|
||||
)
|
||||
return nil, fmt.Errorf("failed to parse SM2 public key (DER length: %d): %w", len(der), err)
|
||||
}
|
||||
|
||||
log.Debug("SM2 public key parsed from DER successfully")
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// SignMessage 使用密钥对签名消息(标准SM2签名).
|
||||
// 使用标准SM2算法对消息进行签名,不包含用户标识(uid).
|
||||
func (kp *SM2KeyPair) SignMessage(msg []byte) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing message with SM2 key pair",
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
if kp.Private == nil {
|
||||
log.Error("Private key is nil")
|
||||
return nil, ErrPrivateKeyIsNil
|
||||
}
|
||||
|
||||
signature, err := crpt.SignMessage(kp.Private, msg, rand.Reader, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign message with SM2",
|
||||
"error", err,
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Message signed successfully with SM2",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// SignGM 使用密钥对签名消息(国密标准SM2签名,带uid).
|
||||
// 使用符合GB/T 32918标准的SM2算法对消息进行签名,包含用户标识(uid).
|
||||
// uid用于Z值计算,通常为用户ID或标识符.
|
||||
func (kp *SM2KeyPair) SignGM(msg, uid []byte) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing message with SM2 GM standard",
|
||||
"messageLength", len(msg),
|
||||
"uidLength", len(uid),
|
||||
)
|
||||
if kp.Private == nil {
|
||||
log.Error("Private key is nil")
|
||||
return nil, ErrPrivateKeyIsNil
|
||||
}
|
||||
|
||||
// go-crpt uses SM3 hash internally, pass nil for standard signing
|
||||
signature, err := crpt.SignMessage(kp.Private, msg, rand.Reader, nil)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign message with SM2 GM standard",
|
||||
"error", err,
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Message signed successfully with SM2 GM standard",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// VerifyMessage 使用公钥验证签名(标准SM2验签).
|
||||
// 验证标准SM2签名,不使用用户标识(uid).
|
||||
// 返回验证结果和可能的错误.如果验证失败但没有错误发生,返回(false, nil).
|
||||
func (kp *SM2KeyPair) VerifyMessage(msg, sig []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying message signature with SM2",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(sig),
|
||||
)
|
||||
if kp.Public == nil {
|
||||
log.Error("Public key is nil")
|
||||
return false, ErrPublicAndKeysNotMatch
|
||||
}
|
||||
|
||||
ok, err := crpt.VerifyMessage(kp.Public, msg, crpt.Signature(sig), nil)
|
||||
if err != nil {
|
||||
log.Error("Error verifying message with SM2",
|
||||
"error", err,
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
log.Debug("Message signature verified successfully with SM2",
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
} else {
|
||||
log.Warn("Message signature verification failed with SM2",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(sig),
|
||||
)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// VerifyGM 使用公钥验证签名(国密标准SM2验签,带uid).
|
||||
// 验证符合GB/T 32918标准的SM2签名,使用用户标识(uid).
|
||||
// 返回验证结果和可能的错误.如果验证失败但没有错误发生,返回(false, nil).
|
||||
func (kp *SM2KeyPair) VerifyGM(msg, sig, uid []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying message signature with SM2 GM standard",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(sig),
|
||||
"uidLength", len(uid),
|
||||
)
|
||||
if kp.Public == nil {
|
||||
log.Error("Public key is nil")
|
||||
return false, ErrPublicAndKeysNotMatch
|
||||
}
|
||||
|
||||
// go-crpt uses SM3 hash internally
|
||||
ok, err := crpt.VerifyMessage(kp.Public, msg, crpt.Signature(sig), nil)
|
||||
if err != nil {
|
||||
log.Error("Error verifying message with SM2 GM standard",
|
||||
"error", err,
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
log.Debug("Message signature verified successfully with SM2 GM standard",
|
||||
"messageLength", len(msg),
|
||||
)
|
||||
} else {
|
||||
log.Warn("Message signature verification failed with SM2 GM standard",
|
||||
"messageLength", len(msg),
|
||||
"signatureLength", len(sig),
|
||||
)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
253
api/model/signature_test.go
Normal file
253
api/model/signature_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestComputeSignature_EmptyPrivateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ComputeSignature([]byte("data"), nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "private key cannot be empty")
|
||||
}
|
||||
|
||||
func TestComputeSignature_EmptyData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privateKey := []byte("invalid-key")
|
||||
_, err := model.ComputeSignature(nil, privateKey)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data to sign cannot be empty")
|
||||
}
|
||||
|
||||
func TestComputeSignature_InvalidKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ComputeSignature([]byte("data"), []byte("invalid-key"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse SM2 private key")
|
||||
}
|
||||
|
||||
func TestVerifySignature_EmptyPublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.VerifySignature([]byte("data"), nil, []byte("signature"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "public key cannot be empty")
|
||||
}
|
||||
|
||||
func TestVerifySignature_EmptyData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publicKey := []byte("invalid-key")
|
||||
_, err := model.VerifySignature(nil, publicKey, []byte("signature"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data to verify cannot be empty")
|
||||
}
|
||||
|
||||
func TestVerifySignature_InvalidKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publicKey := []byte("invalid-key")
|
||||
valid, err := model.VerifySignature([]byte("data"), publicKey, []byte("signature"))
|
||||
require.Error(t, err)
|
||||
assert.False(t, valid)
|
||||
assert.Contains(t, err.Error(), "failed to parse SM2 public key")
|
||||
}
|
||||
|
||||
func TestGenerateSM2KeyPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, keyPair)
|
||||
assert.NotNil(t, keyPair.Public)
|
||||
assert.NotNil(t, keyPair.Private)
|
||||
}
|
||||
|
||||
func TestMarshalSM2PrivateDER_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.MarshalSM2PrivateDER(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "private key is nil")
|
||||
}
|
||||
|
||||
func TestMarshalSM2PrivateDER(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
der, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, der)
|
||||
assert.NotEmpty(t, der)
|
||||
}
|
||||
|
||||
func TestParseSM2PrivateDER_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ParseSM2PrivateDER(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "DER encoded private key cannot be empty")
|
||||
}
|
||||
|
||||
func TestParseSM2PrivateDER_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ParseSM2PrivateDER([]byte("invalid-der"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse SM2 private key from DER")
|
||||
}
|
||||
|
||||
func TestParseSM2PrivateDER_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
der, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKey, err := model.ParseSM2PrivateDER(der)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, parsedKey)
|
||||
}
|
||||
|
||||
func TestMarshalSM2PublicDER_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.MarshalSM2PublicDER(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "public key is nil")
|
||||
}
|
||||
|
||||
func TestMarshalSM2PublicDER(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
der, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, der)
|
||||
assert.NotEmpty(t, der)
|
||||
}
|
||||
|
||||
func TestParseSM2PublicDER_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ParseSM2PublicDER(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "DER encoded public key cannot be empty")
|
||||
}
|
||||
|
||||
func TestParseSM2PublicDER_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := model.ParseSM2PublicDER([]byte("invalid-der"))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse SM2 public key")
|
||||
}
|
||||
|
||||
func TestParseSM2PublicDER_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
der, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKey, err := model.ParseSM2PublicDER(der)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, parsedKey)
|
||||
}
|
||||
|
||||
func TestSM2SignAndVerify_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate key pair
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal keys
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign data
|
||||
data := []byte("test data")
|
||||
signature, err := model.ComputeSignature(data, privateKeyDER)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, signature)
|
||||
assert.NotEmpty(t, signature)
|
||||
|
||||
// Verify signature
|
||||
valid, err := model.VerifySignature(data, publicKeyDER, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestSM2SignAndVerify_WrongData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate key pair
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal keys
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign data
|
||||
data := []byte("test data")
|
||||
signature, err := model.ComputeSignature(data, privateKeyDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with wrong data
|
||||
wrongData := []byte("wrong data")
|
||||
valid, err := model.VerifySignature(wrongData, publicKeyDER, signature)
|
||||
// Verification should return error
|
||||
require.Error(t, err)
|
||||
assert.False(t, valid)
|
||||
assert.Contains(t, err.Error(), "signature verification failed")
|
||||
}
|
||||
|
||||
func TestSM2SignAndVerify_WrongSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate key pair
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Marshal keys
|
||||
privateKeyDER, err := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyDER, err := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sign data
|
||||
data := []byte("test data")
|
||||
_, err = model.ComputeSignature(data, privateKeyDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with wrong signature
|
||||
wrongSignature := []byte("wrong signature")
|
||||
valid, err := model.VerifySignature(data, publicKeyDER, wrongSignature)
|
||||
require.Error(t, err) // Should fail verification
|
||||
assert.False(t, valid)
|
||||
}
|
||||
155
api/model/signer.go
Normal file
155
api/model/signer.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// Signer 签名器接口,用于抽象不同的签名算法实现。
|
||||
// 实现了此接口的类型可以提供签名和验签功能。
|
||||
//
|
||||
// SDK 默认使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||
// 可通过 SetGlobalCryptoConfig 切换到其他算法(如 Ed25519)。
|
||||
type Signer interface {
|
||||
// Sign 对数据进行签名。
|
||||
// 参数:
|
||||
// - data: 待签名的原始数据
|
||||
// 返回: 签名字节数组和可能的错误
|
||||
Sign(data []byte) ([]byte, error)
|
||||
|
||||
// Verify 验证签名。
|
||||
// 参数:
|
||||
// - data: 原始数据
|
||||
// - signature: 签名字节数组
|
||||
//
|
||||
// 返回: 验证是否成功和可能的错误
|
||||
Verify(data, signature []byte) (bool, error)
|
||||
}
|
||||
|
||||
// SM2Signer SM2签名器实现。
|
||||
// 使用SM2算法进行签名和验签(内部自动使用 SM3 哈希)。
|
||||
//
|
||||
// 这是 SDK 的默认签名算法。如需使用其他算法,请使用 ConfigSigner。
|
||||
type SM2Signer struct {
|
||||
privateKey []byte // 私钥(DER编码格式)
|
||||
publicKey []byte // 公钥(DER编码格式)
|
||||
}
|
||||
|
||||
// NewSM2Signer 创建新的SM2签名器。
|
||||
// 这是 SDK 默认推荐的签名器,使用 SM2 算法(内部自动使用 SM3 哈希)。
|
||||
//
|
||||
// 参数:
|
||||
// - privateKey: 私钥(DER编码格式),用于签名
|
||||
// - publicKey: 公钥(DER编码格式),用于验签
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// keyPair, _ := model.GenerateSM2KeyPair()
|
||||
// privateKeyDER, _ := model.MarshalSM2PrivateDER(keyPair.Private)
|
||||
// publicKeyDER, _ := model.MarshalSM2PublicDER(keyPair.Public)
|
||||
// signer := model.NewSM2Signer(privateKeyDER, publicKeyDER)
|
||||
func NewSM2Signer(privateKey, publicKey []byte) *SM2Signer {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new SM2 signer (default algorithm, uses SM3 hash)",
|
||||
"privateKeyLength", len(privateKey),
|
||||
"publicKeyLength", len(publicKey),
|
||||
)
|
||||
return &SM2Signer{
|
||||
privateKey: privateKey,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Sign 使用SM2私钥对数据进行签名(内部使用 SM3 哈希)。
|
||||
func (s *SM2Signer) Sign(data []byte) ([]byte, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Signing data with SM2 (using SM3 hash)",
|
||||
"dataLength", len(data),
|
||||
"privateKeyLength", len(s.privateKey),
|
||||
)
|
||||
signature, err := ComputeSignature(data, s.privateKey)
|
||||
if err != nil {
|
||||
log.Error("Failed to sign data with SM2",
|
||||
"error", err,
|
||||
"dataLength", len(data),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("Data signed successfully with SM2",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// Verify 使用SM2公钥验证签名(内部使用 SM3 哈希)。
|
||||
// 注意: go-crpt 库会自动使用 SM3 算法计算摘要并验证。
|
||||
// 返回: 验证是否成功和可能的错误.
|
||||
func (s *SM2Signer) Verify(data, signature []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Verifying signature with SM2",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
"publicKeyLength", len(s.publicKey),
|
||||
)
|
||||
valid, err := VerifySignature(data, s.publicKey, signature)
|
||||
if err != nil {
|
||||
log.Error("Failed to verify signature with SM2",
|
||||
"error", err,
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
return false, err
|
||||
}
|
||||
if valid {
|
||||
log.Debug("Signature verified successfully with SM2",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
} else {
|
||||
log.Warn("Signature verification failed with SM2",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
}
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
// NopSigner 空操作签名器实现。
|
||||
// 对原hash不做任何操作,直接返回原数据。
|
||||
// 适用于不需要实际签名操作的场景,如测试或某些特殊用途。
|
||||
type NopSigner struct{}
|
||||
|
||||
// NewNopSigner 创建新的空操作签名器。
|
||||
func NewNopSigner() *NopSigner {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("Creating new NopSigner")
|
||||
return &NopSigner{}
|
||||
}
|
||||
|
||||
// Sign 直接返回原数据,不做任何签名操作。
|
||||
func (n *NopSigner) Sign(_ []byte) ([]byte, error) {
|
||||
|
||||
return ([]byte)("test"), nil
|
||||
}
|
||||
|
||||
// Verify 验证签名是否等于原数据。
|
||||
func (n *NopSigner) Verify(data, signature []byte) (bool, error) {
|
||||
log := logger.GetGlobalLogger()
|
||||
log.Debug("NopSigner: verifying signature",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
valid := bytes.Equal(data, signature)
|
||||
if valid {
|
||||
log.Debug("NopSigner: signature verified successfully",
|
||||
"dataLength", len(data),
|
||||
)
|
||||
} else {
|
||||
log.Warn("NopSigner: signature verification failed",
|
||||
"dataLength", len(data),
|
||||
"signatureLength", len(signature),
|
||||
)
|
||||
}
|
||||
return valid, nil
|
||||
}
|
||||
135
api/model/signer_test.go
Normal file
135
api/model/signer_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestNewSM2Signer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privateKey := []byte("test-private-key")
|
||||
publicKey := []byte("test-public-key")
|
||||
|
||||
signer := model.NewSM2Signer(privateKey, publicKey)
|
||||
assert.NotNil(t, signer)
|
||||
}
|
||||
|
||||
func TestNewNopSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
assert.NotNil(t, signer)
|
||||
}
|
||||
|
||||
func TestNopSigner_Sign(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
data := []byte("test data")
|
||||
|
||||
result, err := signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, result)
|
||||
assert.NotSame(t, &data[0], &result[0]) // Should be a copy
|
||||
}
|
||||
|
||||
func TestNopSigner_Sign_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
data := []byte{}
|
||||
|
||||
result, err := signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, result)
|
||||
}
|
||||
|
||||
func TestNopSigner_Verify_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
data := []byte("test data")
|
||||
signature := []byte("test data") // Same as data
|
||||
|
||||
valid, err := signer.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestNopSigner_Verify_Failure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
data := []byte("test data")
|
||||
signature := []byte("different data")
|
||||
|
||||
valid, err := signer.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
}
|
||||
|
||||
func TestNopSigner_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
data := []byte("test data")
|
||||
|
||||
signature, err := signer.Sign(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
valid, err := signer.Verify(data, signature)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
}
|
||||
|
||||
func TestNopSigner_Verify_DifferentLengths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signer := model.NewNopSigner()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
signature []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "same data",
|
||||
data: []byte("test"),
|
||||
signature: []byte("test"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different data",
|
||||
data: []byte("test"),
|
||||
signature: []byte("test2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different lengths",
|
||||
data: []byte("test"),
|
||||
signature: []byte("test1"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
data: []byte{},
|
||||
signature: []byte{},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
valid, err := signer.Verify(tt.data, tt.signature)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
65
api/model/sm2_consistency_test.go
Normal file
65
api/model/sm2_consistency_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestSM2HashConsistency 验证SM2加签和验签的一致性
|
||||
// 关键发现:SM2库内部会处理hash,但加签和验签必须使用相同的数据类型.
|
||||
func TestSM2HashConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试数据
|
||||
originalData := []byte("test data for consistency check")
|
||||
|
||||
t.Logf("=== 测试1:加签和验签都使用原始数据(当前实现)===")
|
||||
// 1. 加签:使用原始数据
|
||||
signature1, err := keyPair.SignMessage(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. 验签:使用原始数据
|
||||
valid1, err := keyPair.VerifyMessage(originalData, signature1)
|
||||
require.NoError(t, err)
|
||||
t.Logf("加签(原始数据) + 验签(原始数据): %v", valid1)
|
||||
assert.True(t, valid1, "加签和验签都使用原始数据应该成功")
|
||||
|
||||
t.Logf("\n=== 测试2:加签和验签都使用hash值 ===")
|
||||
// 3. 加签:使用hash值
|
||||
hashBytes := sha256.Sum256(originalData)
|
||||
signature2, err := keyPair.SignMessage(hashBytes[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. 验签:使用hash值
|
||||
valid2, err := keyPair.VerifyMessage(hashBytes[:], signature2)
|
||||
require.NoError(t, err)
|
||||
t.Logf("加签(hash值) + 验签(hash值): %v", valid2)
|
||||
assert.True(t, valid2, "加签和验签都使用hash值应该成功")
|
||||
|
||||
t.Logf("\n=== 测试3:不一致的情况(应该失败)===")
|
||||
// 5. 加签使用原始数据,验签使用hash值 - 应该失败
|
||||
valid3, err := keyPair.VerifyMessage(hashBytes[:], signature1)
|
||||
require.NoError(t, err)
|
||||
t.Logf("加签(原始数据) + 验签(hash值): %v", valid3)
|
||||
assert.False(t, valid3, "加签和验签使用不同类型数据应该失败")
|
||||
|
||||
// 6. 加签使用hash值,验签使用原始数据 - 应该失败
|
||||
valid4, err := keyPair.VerifyMessage(originalData, signature2)
|
||||
require.NoError(t, err)
|
||||
t.Logf("加签(hash值) + 验签(原始数据): %v", valid4)
|
||||
assert.False(t, valid4, "加签和验签使用不同类型数据应该失败")
|
||||
|
||||
t.Logf("\n=== 结论 ===")
|
||||
t.Logf("✓ SM2库内部会处理hash")
|
||||
t.Logf("✓ 加签和验签必须使用相同的数据类型(都是原始数据,或都是hash值)")
|
||||
t.Logf("✓ 当前实现(加签和验签都使用原始数据)是正确的")
|
||||
}
|
||||
82
api/model/sm2_hash_test.go
Normal file
82
api/model/sm2_hash_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
// TestSM2RequiresHash 测试SM2是否要求预先hash数据
|
||||
// 根据文档,SM2.SignASN1期望接收hash值,而不是原始数据
|
||||
// 但文档也提到:如果opts是*SM2SignerOption且ForceGMSign为true,则hash会被视为原始消息.
|
||||
func TestSM2RequiresHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// 生成SM2密钥对
|
||||
keyPair, err := model.GenerateSM2KeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试数据
|
||||
originalData := []byte("test data for SM2 signing")
|
||||
|
||||
// 1. 直接对原始数据签名(当前实现的方式)
|
||||
// go-crpt 库会自动使用 SM3 计算摘要
|
||||
signature1, err := keyPair.SignMessage(originalData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signature1)
|
||||
|
||||
// 2. 验证签名(使用原始数据)
|
||||
valid1, err := keyPair.VerifyMessage(originalData, signature1)
|
||||
require.NoError(t, err)
|
||||
t.Logf("直接使用原始数据签名和验证结果: %v", valid1)
|
||||
assert.True(t, valid1, "当前实现:直接对原始数据签名和验证应该成功")
|
||||
|
||||
// 3. 先hash再签名(文档推荐的方式)
|
||||
hashBytesReal := sha256.Sum256(originalData)
|
||||
|
||||
signature2, err := keyPair.SignMessage(hashBytesReal[:])
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signature2)
|
||||
|
||||
// 4. 验证签名(使用hash值)
|
||||
valid2, err := keyPair.VerifyMessage(hashBytesReal[:], signature2)
|
||||
require.NoError(t, err)
|
||||
t.Logf("先hash再签名和验证结果: %v", valid2)
|
||||
assert.True(t, valid2, "先hash再签名和验证应该成功")
|
||||
|
||||
// 5. 交叉验证:用原始数据验证hash后的签名
|
||||
valid3, err := keyPair.VerifyMessage(originalData, signature2)
|
||||
require.NoError(t, err)
|
||||
t.Logf("用原始数据验证hash后的签名结果: %v", valid3)
|
||||
|
||||
// 6. 交叉验证:用hash值验证原始数据的签名
|
||||
valid4, err := keyPair.VerifyMessage(hashBytesReal[:], signature1)
|
||||
require.NoError(t, err)
|
||||
t.Logf("用hash值验证原始数据的签名结果: %v", valid4)
|
||||
|
||||
// 结论:
|
||||
// - 如果valid1=true且valid4=false,说明SM2内部可能处理了hash,或者有某种兼容性
|
||||
// - 如果valid1=true且valid4=true,说明SM2可能接受原始数据(不符合文档)
|
||||
// - 如果valid1=false,说明SM2确实需要hash值
|
||||
|
||||
t.Logf("\n结论分析:")
|
||||
t.Logf("- 直接对原始数据签名和验证: %v", valid1)
|
||||
t.Logf("- 先hash再签名和验证: %v", valid2)
|
||||
t.Logf("- 交叉验证1(原始数据 vs hash签名): %v", valid3)
|
||||
t.Logf("- 交叉验证2(hash数据 vs 原始签名): %v", valid4)
|
||||
|
||||
switch {
|
||||
case valid1 && !valid4:
|
||||
t.Logf("✓ SM2库可能内部处理了hash,或者有兼容性机制")
|
||||
t.Logf("✓ 当前实现(直接使用原始数据)可能是可行的")
|
||||
case valid1 && valid4:
|
||||
t.Logf("⚠ SM2库可能接受原始数据,与文档不符")
|
||||
t.Logf("⚠ 但当前实现可以工作")
|
||||
default:
|
||||
t.Logf("✗ SM2确实需要hash值,当前实现可能有问题")
|
||||
}
|
||||
}
|
||||
13
api/model/trustlog.go
Normal file
13
api/model/trustlog.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package model
|
||||
|
||||
import "encoding"
|
||||
|
||||
// Trustlog 接口定义了信任日志的基本操作。
|
||||
// 实现了此接口的类型可以进行序列化、反序列化、哈希计算和提供生产者ID。
|
||||
type Trustlog interface {
|
||||
Hashable
|
||||
encoding.BinaryMarshaler
|
||||
encoding.BinaryUnmarshaler
|
||||
GetProducerID() string
|
||||
Key() string
|
||||
}
|
||||
32
api/model/validation.go
Normal file
32
api/model/validation.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package model
|
||||
|
||||
// Validation status codes.
|
||||
const (
|
||||
ValidationCodeProcessing = 100 // 处理中
|
||||
ValidationCodeCompleted = 200 // 完成
|
||||
ValidationCodeFailed = 500 // 失败
|
||||
)
|
||||
|
||||
// ValidationResult 包装取证的流式响应结果.
|
||||
type ValidationResult struct {
|
||||
Code int32 // 状态码(100处理中,200完成,500失败)
|
||||
Msg string // 消息描述
|
||||
Progress string // 当前进度(比如 "50%")
|
||||
Data *Operation // 最终完成时返回的操作数据,过程中可为空
|
||||
Proof *Proof // 取证证明(仅在完成时返回)
|
||||
}
|
||||
|
||||
// IsProcessing 判断是否正在处理中.
|
||||
func (v *ValidationResult) IsProcessing() bool {
|
||||
return v.Code == ValidationCodeProcessing
|
||||
}
|
||||
|
||||
// IsCompleted 判断是否已完成.
|
||||
func (v *ValidationResult) IsCompleted() bool {
|
||||
return v.Code == ValidationCodeCompleted
|
||||
}
|
||||
|
||||
// IsFailed 判断是否失败.
|
||||
func (v *ValidationResult) IsFailed() bool {
|
||||
return v.Code >= ValidationCodeFailed
|
||||
}
|
||||
238
api/model/validation_test.go
Normal file
238
api/model/validation_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/model"
|
||||
)
|
||||
|
||||
func TestValidationResult_IsProcessing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "other code",
|
||||
code: 99,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
v := &model.ValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, v.IsProcessing())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationResult_IsCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "other code",
|
||||
code: 99,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
v := &model.ValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, v.IsCompleted())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationResult_IsFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "code greater than failed",
|
||||
code: 501,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "code less than failed",
|
||||
code: 499,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
v := &model.ValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, v.IsFailed())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordValidationResult_IsProcessing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &model.RecordValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, r.IsProcessing())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordValidationResult_IsCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &model.RecordValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, r.IsCompleted())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordValidationResult_IsFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code int32
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "processing code",
|
||||
code: model.ValidationCodeProcessing,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completed code",
|
||||
code: model.ValidationCodeCompleted,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "failed code",
|
||||
code: model.ValidationCodeFailed,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "code greater than failed",
|
||||
code: 501,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &model.RecordValidationResult{Code: tt.code}
|
||||
assert.Equal(t, tt.expected, r.IsFailed())
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user