package adapter

import (
	"bytes"
	"testing"
)

// mustRoundTripTCPMessage encodes msg with EncodeTCPMessage, decodes the
// resulting bytes with DecodeTCPMessage, and returns the decoded message.
// Any encode or decode error aborts the calling test.
func mustRoundTripTCPMessage(t *testing.T, msg *TCPMessage) *TCPMessage {
	t.Helper()
	encoded, err := EncodeTCPMessage(msg)
	if err != nil {
		t.Fatalf("EncodeTCPMessage() error = %v", err)
	}
	decoded, err := DecodeTCPMessage(bytes.NewReader(encoded))
	if err != nil {
		t.Fatalf("DecodeTCPMessage() error = %v", err)
	}
	return decoded
}

// assertTCPMessageEqual reports every field of got that differs from want.
// Payloads are compared with bytes.Equal, so a nil payload and an empty
// non-nil payload are considered equal.
func assertTCPMessageEqual(t *testing.T, got, want *TCPMessage) {
	t.Helper()
	// Fail cleanly instead of panicking on a field access below.
	if got == nil {
		t.Fatal("decoded message is nil")
	}
	if got.Type != want.Type {
		t.Errorf("Type mismatch: got %v, want %v", got.Type, want.Type)
	}
	if got.Topic != want.Topic {
		t.Errorf("Topic mismatch: got %q, want %q", got.Topic, want.Topic)
	}
	if got.UUID != want.UUID {
		t.Errorf("UUID mismatch: got %q, want %q", got.UUID, want.UUID)
	}
	if !bytes.Equal(got.Payload, want.Payload) {
		t.Errorf("Payload mismatch: got %v, want %v", got.Payload, want.Payload)
	}
}

// patternedPayload returns n bytes of a repeating non-zero pattern. Unlike an
// all-zero buffer, shifted or misaligned payload bytes produce a visible
// content mismatch.
func patternedPayload(n int) []byte {
	p := make([]byte, n)
	for i := range p {
		p[i] = byte(i%251 + 1)
	}
	return p
}

// TestEncodeTCPMessage checks that EncodeTCPMessage accepts valid data and ack
// messages, producing non-empty output, and rejects a nil message with an error.
func TestEncodeTCPMessage(t *testing.T) {
	tests := []struct {
		name    string
		msg     *TCPMessage
		wantErr bool
	}{
		{
			name: "valid data message",
			msg: &TCPMessage{
				Type:    MessageTypeData,
				Topic:   "test-topic",
				UUID:    "test-uuid-1234",
				Payload: []byte("test payload"),
			},
			wantErr: false,
		},
		{
			name: "valid ack message",
			msg: &TCPMessage{
				Type:    MessageTypeAck,
				Topic:   "",
				UUID:    "test-uuid-5678",
				Payload: nil,
			},
			wantErr: false,
		},
		{
			name:    "nil message",
			msg:     nil,
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			data, err := EncodeTCPMessage(tt.msg)
			if (err != nil) != tt.wantErr {
				t.Errorf("EncodeTCPMessage() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// An empty non-nil slice is as wrong as nil for a valid message.
			if !tt.wantErr && len(data) == 0 {
				t.Error("EncodeTCPMessage() returned empty data")
			}
		})
	}
}

// TestDecodeTCPMessage encodes a data message and checks that
// DecodeTCPMessage restores every field.
func TestDecodeTCPMessage(t *testing.T) {
	original := &TCPMessage{
		Type:    MessageTypeData,
		Topic:   "test-topic",
		UUID:    "test-uuid-1234",
		Payload: []byte("test payload data"),
	}

	decoded := mustRoundTripTCPMessage(t, original)
	assertTCPMessageEqual(t, decoded, original)
}

// TestEncodeDecodeRoundTrip checks that encoding followed by decoding
// preserves all fields for each message type, for empty and nil payloads,
// and for a large payload.
func TestEncodeDecodeRoundTrip(t *testing.T) {
	testCases := []struct {
		name string
		msg  *TCPMessage
	}{
		{
			name: "data message with payload",
			msg: &TCPMessage{
				Type:    MessageTypeData,
				Topic:   "persistent://public/default/test",
				UUID:    "550e8400-e29b-41d4-a716-446655440000",
				Payload: []byte("Hello, World!"),
			},
		},
		{
			name: "ack message",
			msg: &TCPMessage{
				Type:    MessageTypeAck,
				Topic:   "",
				UUID:    "test-uuid",
				Payload: nil,
			},
		},
		{
			name: "nack message",
			msg: &TCPMessage{
				Type:    MessageTypeNack,
				Topic:   "",
				UUID:    "another-uuid",
				Payload: []byte{},
			},
		},
		{
			name: "message with large payload",
			msg: &TCPMessage{
				Type:    MessageTypeData,
				Topic:   "test",
				UUID:    "uuid",
				Payload: patternedPayload(10000),
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			decoded := mustRoundTripTCPMessage(t, tc.msg)
			assertTCPMessageEqual(t, decoded, tc.msg)
		})
	}
}