refactor: 重构trustlog-sdk目录结构到trustlog/go-trustlog
- 将所有trustlog-sdk文件移动到trustlog/go-trustlog/目录 - 更新README中所有import路径从trustlog-sdk改为go-trustlog - 更新cookiecutter配置文件中的项目名称 - 更新根目录.lefthook.yml以引用新位置的配置 - 添加go.sum文件到版本控制 - 删除过时的示例文件 这次重构与trustlog-server保持一致的目录结构, 为未来支持多语言SDK(Python、Java等)预留空间。
This commit is contained in:
205
api/adapter/TCP_QUICK_START.md
Normal file
205
api/adapter/TCP_QUICK_START.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# TCP 适配器快速开始指南
|
||||
|
||||
## 简介
|
||||
|
||||
TCP 适配器提供了一个无需 Pulsar 的 Watermill 消息发布/订阅实现,适用于内网直连场景。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 启动消费端(Subscriber)
|
||||
|
||||
消费端作为 TCP 服务器,监听指定端口。
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 使用 NopLogger 或自定义 logger
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建 Subscriber
|
||||
config := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:9090",
|
||||
}
|
||||
|
||||
subscriber, err := adapter.NewTCPSubscriber(config, log)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer subscriber.Close()
|
||||
|
||||
// 订阅 topic
|
||||
messages, err := subscriber.Subscribe(context.Background(), "my-topic")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
for msg := range messages {
|
||||
log.Println("收到消息:", string(msg.Payload))
|
||||
msg.Ack() // 确认消息
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 启动生产端(Publisher)
|
||||
|
||||
生产端作为 TCP 客户端,连接到消费端。
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:9090",
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
AckTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer publisher.Close()
|
||||
|
||||
// 发送消息
|
||||
msg := message.NewMessage("msg-001", []byte("Hello, World!"))
|
||||
|
||||
err = publisher.Publish("my-topic", msg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Println("消息发送成功")
|
||||
}
|
||||
```
|
||||
|
||||
## 特性演示
|
||||
|
||||
### 并发发送多条消息
|
||||
|
||||
```go
|
||||
// 准备 10 条消息
|
||||
messages := make([]*message.Message, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
payload := []byte(fmt.Sprintf("Message #%d", i))
|
||||
messages[i] = message.NewMessage(fmt.Sprintf("msg-%d", i), payload)
|
||||
}
|
||||
|
||||
// 并发发送,Publisher 会等待所有 ACK
|
||||
err := publisher.Publish("my-topic", messages...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Println("所有消息发送成功")
|
||||
```
|
||||
|
||||
### 错误处理和 NACK
|
||||
|
||||
```go
|
||||
// 在消费端
|
||||
for msg := range messages {
|
||||
// 处理消息
|
||||
if err := processMessage(msg); err != nil {
|
||||
log.Println("处理失败:", err)
|
||||
msg.Nack() // 拒绝消息
|
||||
continue
|
||||
}
|
||||
msg.Ack() // 确认消息
|
||||
}
|
||||
```
|
||||
|
||||
## 配置参数
|
||||
|
||||
### TCPPublisherConfig
|
||||
|
||||
```go
|
||||
type TCPPublisherConfig struct {
|
||||
ServerAddr string // 必填: TCP 服务器地址,如 "127.0.0.1:9090"
|
||||
ConnectTimeout time.Duration // 连接超时,默认 10s
|
||||
AckTimeout time.Duration // ACK 超时,默认 30s
|
||||
MaxRetries int // 最大重试次数,默认 3
|
||||
}
|
||||
```
|
||||
|
||||
### TCPSubscriberConfig
|
||||
|
||||
```go
|
||||
type TCPSubscriberConfig struct {
|
||||
ListenAddr string // 必填: 监听地址,如 "127.0.0.1:9090"
|
||||
}
|
||||
```
|
||||
|
||||
## 运行示例
|
||||
|
||||
```bash
|
||||
# 运行完整示例
|
||||
cd trustlog-sdk/examples
|
||||
go run tcp_example.go
|
||||
```
|
||||
|
||||
## 性能特点
|
||||
|
||||
- ✅ **低延迟**: 直接 TCP 连接,无中间件开销
|
||||
- ✅ **高并发**: 支持并发发送多条消息
|
||||
- ✅ **可靠性**: 每条消息都需要 ACK 确认
|
||||
- ⚠️ **无持久化**: 消息仅在内存中传递
|
||||
|
||||
## 适用场景
|
||||
|
||||
✅ **适合:**
|
||||
- 内网服务间直接通信
|
||||
- 开发和测试环境
|
||||
- 无需消息持久化的场景
|
||||
- 低延迟要求的场景
|
||||
|
||||
❌ **不适合:**
|
||||
- 需要消息持久化
|
||||
- 需要高可用和故障恢复
|
||||
- 公网通信(需要加密)
|
||||
- 需要复杂的路由和负载均衡
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 如何处理连接断开?
|
||||
|
||||
A: 当前版本连接断开后需要重新创建 Publisher。未来版本将支持自动重连。
|
||||
|
||||
### Q: 消息会丢失吗?
|
||||
|
||||
A: TCP 适配器不提供持久化,连接断开或服务重启会导致未确认的消息丢失。
|
||||
|
||||
### Q: 如何实现多个消费者?
|
||||
|
||||
A: 当前版本将消息发送到第一个订阅者。如需负载均衡,需要在应用层实现。
|
||||
|
||||
### Q: 支持 TLS 加密吗?
|
||||
|
||||
A: 当前版本不支持 TLS。未来版本将添加 TLS/mTLS 支持。
|
||||
|
||||
## 下一步
|
||||
|
||||
- 查看 [完整文档](TCP_ADAPTER_README.md)
|
||||
- 运行 [测试用例](tcp_integration_test.go)
|
||||
- 查看 [示例代码](../../examples/tcp_example.go)
|
||||
|
||||
608
api/adapter/mocks/pulsar_mock.go
Normal file
608
api/adapter/mocks/pulsar_mock.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
)
|
||||
|
||||
// MockPulsarClient is a mock implementation of pulsar.Client.
|
||||
type MockPulsarClient struct {
|
||||
mu sync.RWMutex
|
||||
producers map[string]*MockProducer
|
||||
consumers map[string]*MockConsumer
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewMockPulsarClient creates a new mock Pulsar client.
|
||||
func NewMockPulsarClient() *MockPulsarClient {
|
||||
return &MockPulsarClient{
|
||||
producers: make(map[string]*MockProducer),
|
||||
consumers: make(map[string]*MockConsumer),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProducer creates a mock producer.
|
||||
func (m *MockPulsarClient) CreateProducer(options pulsar.ProducerOptions) (pulsar.Producer, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, errors.New("client is closed")
|
||||
}
|
||||
|
||||
if m.producers == nil {
|
||||
m.producers = make(map[string]*MockProducer)
|
||||
}
|
||||
|
||||
producer := NewMockProducer(options.Topic)
|
||||
m.producers[options.Topic] = producer
|
||||
return producer, nil
|
||||
}
|
||||
|
||||
// Subscribe creates a mock consumer.
|
||||
func (m *MockPulsarClient) Subscribe(options pulsar.ConsumerOptions) (pulsar.Consumer, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, errors.New("client is closed")
|
||||
}
|
||||
|
||||
if m.consumers == nil {
|
||||
m.consumers = make(map[string]*MockConsumer)
|
||||
}
|
||||
|
||||
consumer := NewMockConsumer(options.Topic, options.Name)
|
||||
m.consumers[options.Name] = consumer
|
||||
return consumer, nil
|
||||
}
|
||||
|
||||
// CreateReader is not implemented.
|
||||
func (m *MockPulsarClient) CreateReader(options pulsar.ReaderOptions) (pulsar.Reader, error) {
|
||||
return nil, errors.New("CreateReader not implemented")
|
||||
}
|
||||
|
||||
// CreateTableView is not implemented.
|
||||
func (m *MockPulsarClient) CreateTableView(options pulsar.TableViewOptions) (pulsar.TableView, error) {
|
||||
return nil, errors.New("CreateTableView not implemented")
|
||||
}
|
||||
|
||||
// NewTransaction creates a new transaction.
|
||||
func (m *MockPulsarClient) NewTransaction(timeout time.Duration) (pulsar.Transaction, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// TopicPartitions returns the partitions for a topic.
|
||||
func (m *MockPulsarClient) TopicPartitions(topic string) ([]string, error) {
|
||||
return []string{topic}, nil
|
||||
}
|
||||
|
||||
// Close closes the mock client.
|
||||
func (m *MockPulsarClient) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.closed = true
|
||||
for _, producer := range m.producers {
|
||||
producer.Close()
|
||||
}
|
||||
for _, consumer := range m.consumers {
|
||||
consumer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetProducer returns a producer by topic (for testing).
|
||||
func (m *MockPulsarClient) GetProducer(topic string) *MockProducer {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.producers[topic]
|
||||
}
|
||||
|
||||
// GetConsumer returns a consumer by name (for testing).
|
||||
func (m *MockPulsarClient) GetConsumer(name string) *MockConsumer {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.consumers[name]
|
||||
}
|
||||
|
||||
// MockProducer is a mock implementation of pulsar.Producer.
|
||||
type MockProducer struct {
|
||||
topic string
|
||||
name string
|
||||
messages []*pulsar.ProducerMessage
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewMockProducer creates a new mock producer.
|
||||
func NewMockProducer(topic string) *MockProducer {
|
||||
return &MockProducer{
|
||||
topic: topic,
|
||||
name: "mock-producer",
|
||||
messages: make([]*pulsar.ProducerMessage, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Topic returns the topic name.
|
||||
func (m *MockProducer) Topic() string {
|
||||
return m.topic
|
||||
}
|
||||
|
||||
// Name returns the producer name.
|
||||
func (m *MockProducer) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
// Send sends a message.
|
||||
func (m *MockProducer) Send(ctx context.Context, msg *pulsar.ProducerMessage) (pulsar.MessageID, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, errors.New("producer is closed")
|
||||
}
|
||||
|
||||
m.messages = append(m.messages, msg)
|
||||
return &MockMessageID{id: len(m.messages)}, nil
|
||||
}
|
||||
|
||||
// SendAsync sends a message asynchronously.
|
||||
func (m *MockProducer) SendAsync(
|
||||
ctx context.Context,
|
||||
msg *pulsar.ProducerMessage,
|
||||
callback func(pulsar.MessageID, *pulsar.ProducerMessage, error),
|
||||
) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
callback(nil, msg, errors.New("producer is closed"))
|
||||
return
|
||||
}
|
||||
|
||||
m.messages = append(m.messages, msg)
|
||||
callback(&MockMessageID{id: len(m.messages)}, msg, nil)
|
||||
}
|
||||
|
||||
// LastSequenceID returns the last sequence ID.
|
||||
func (m *MockProducer) LastSequenceID() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Flush flushes pending messages.
|
||||
func (m *MockProducer) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushWithCtx flushes pending messages with context.
|
||||
func (m *MockProducer) FlushWithCtx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the producer.
|
||||
func (m *MockProducer) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.closed = true
|
||||
}
|
||||
|
||||
// GetMessages returns all sent messages (for testing).
|
||||
func (m *MockProducer) GetMessages() []*pulsar.ProducerMessage {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]*pulsar.ProducerMessage, len(m.messages))
|
||||
copy(result, m.messages)
|
||||
return result
|
||||
}
|
||||
|
||||
// MockConsumer is a mock implementation of pulsar.Consumer.
|
||||
type MockConsumer struct {
|
||||
topic string
|
||||
name string
|
||||
messageChan chan pulsar.ConsumerMessage
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultMessageChannelSize 定义消息通道的默认缓冲大小.
|
||||
defaultMessageChannelSize = 10
|
||||
)
|
||||
|
||||
// NewMockConsumer creates a new mock consumer.
|
||||
func NewMockConsumer(topic, name string) *MockConsumer {
|
||||
return &MockConsumer{
|
||||
topic: topic,
|
||||
name: name,
|
||||
messageChan: make(chan pulsar.ConsumerMessage, defaultMessageChannelSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscription returns the subscription name.
|
||||
func (m *MockConsumer) Subscription() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
// Topic returns the topic name.
|
||||
func (m *MockConsumer) Topic() string {
|
||||
return m.topic
|
||||
}
|
||||
|
||||
// Chan returns the message channel.
|
||||
func (m *MockConsumer) Chan() <-chan pulsar.ConsumerMessage {
|
||||
return m.messageChan
|
||||
}
|
||||
|
||||
// Ack acknowledges a message.
|
||||
func (m *MockConsumer) Ack(msg pulsar.Message) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Nack negatively acknowledges a message.
|
||||
func (m *MockConsumer) Nack(msg pulsar.Message) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
// Mock implementation: 实际不做任何操作
|
||||
_ = msg
|
||||
}
|
||||
|
||||
// NackID negatively acknowledges a message by ID.
|
||||
func (m *MockConsumer) NackID(msgID pulsar.MessageID) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
// Mock implementation: 实际不做任何操作
|
||||
_ = msgID
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes the consumer.
|
||||
func (m *MockConsumer) Unsubscribe() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeForce forcefully unsubscribes the consumer.
|
||||
func (m *MockConsumer) UnsubscribeForce() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive receives a single message.
|
||||
func (m *MockConsumer) Receive(ctx context.Context) (pulsar.Message, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, errors.New("consumer is closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-m.messageChan:
|
||||
return msg.Message, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// AckCumulative acknowledges all messages up to and including the provided message.
|
||||
func (m *MockConsumer) AckCumulative(msg pulsar.Message) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AckID acknowledges a message by ID.
|
||||
func (m *MockConsumer) AckID(msgID pulsar.MessageID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AckIDCumulative acknowledges all messages up to and including the provided message ID.
|
||||
func (m *MockConsumer) AckIDCumulative(msgID pulsar.MessageID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AckIDList acknowledges a list of message IDs.
|
||||
func (m *MockConsumer) AckIDList(msgIDs []pulsar.MessageID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AckWithTxn acknowledges a message with transaction.
|
||||
func (m *MockConsumer) AckWithTxn(msg pulsar.Message, txn pulsar.Transaction) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastMessageIDs returns the last message IDs.
|
||||
func (m *MockConsumer) GetLastMessageIDs() ([]pulsar.TopicMessageID, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, errors.New("consumer is closed")
|
||||
}
|
||||
return []pulsar.TopicMessageID{}, nil
|
||||
}
|
||||
|
||||
// ReconsumeLater reconsumes a message later with delay.
|
||||
func (m *MockConsumer) ReconsumeLater(msg pulsar.Message, delay time.Duration) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
// Mock implementation: 实际不做任何操作
|
||||
_, _ = msg, delay
|
||||
}
|
||||
|
||||
// ReconsumeLaterWithCustomProperties reconsumes a message later with custom properties.
|
||||
func (m *MockConsumer) ReconsumeLaterWithCustomProperties(
|
||||
msg pulsar.Message,
|
||||
customProperties map[string]string,
|
||||
delay time.Duration,
|
||||
) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
// Mock implementation: 实际不做任何操作
|
||||
_, _, _ = msg, customProperties, delay
|
||||
}
|
||||
|
||||
// Seek seeks to a message ID.
|
||||
func (m *MockConsumer) Seek(msgID pulsar.MessageID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SeekByTime seeks to a time.
|
||||
func (m *MockConsumer) SeekByTime(t time.Time) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name returns the consumer name.
|
||||
func (m *MockConsumer) Name() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.name
|
||||
}
|
||||
|
||||
// Close closes the consumer.
|
||||
func (m *MockConsumer) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return
|
||||
}
|
||||
m.closed = true
|
||||
close(m.messageChan)
|
||||
}
|
||||
|
||||
// SendMessage sends a message to the consumer channel (for testing).
|
||||
func (m *MockConsumer) SendMessage(msg pulsar.ConsumerMessage) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return errors.New("consumer is closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case m.messageChan <- msg:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("channel full")
|
||||
}
|
||||
}
|
||||
|
||||
// MockMessageID is a mock implementation of pulsar.MessageID.
|
||||
type MockMessageID struct {
|
||||
id int
|
||||
}
|
||||
|
||||
// Serialize serializes the message ID.
|
||||
func (m *MockMessageID) Serialize() []byte {
|
||||
return []byte{byte(m.id)}
|
||||
}
|
||||
|
||||
// BatchIdx returns the batch index.
|
||||
func (m *MockMessageID) BatchIdx() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// BatchSize returns the batch size.
|
||||
func (m *MockMessageID) BatchSize() int32 {
|
||||
return 1
|
||||
}
|
||||
|
||||
// String returns the string representation of the message ID.
|
||||
func (m *MockMessageID) String() string {
|
||||
return fmt.Sprintf("mock-message-id-%d", m.id)
|
||||
}
|
||||
|
||||
// EntryID returns the entry ID.
|
||||
func (m *MockMessageID) EntryID() int64 {
|
||||
return int64(m.id)
|
||||
}
|
||||
|
||||
// LedgerID returns the ledger ID.
|
||||
func (m *MockMessageID) LedgerID() int64 {
|
||||
return int64(m.id)
|
||||
}
|
||||
|
||||
// PartitionIdx returns the partition index.
|
||||
func (m *MockMessageID) PartitionIdx() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// MockMessage is a mock implementation of pulsar.Message.
|
||||
type MockMessage struct {
|
||||
key string
|
||||
payload []byte
|
||||
id pulsar.MessageID
|
||||
}
|
||||
|
||||
// NewMockMessage creates a new mock message.
|
||||
func NewMockMessage(key string, payload []byte) *MockMessage {
|
||||
return &MockMessage{
|
||||
key: key,
|
||||
payload: payload,
|
||||
id: &MockMessageID{id: 1},
|
||||
}
|
||||
}
|
||||
|
||||
// Topic returns the topic name.
|
||||
func (m *MockMessage) Topic() string {
|
||||
return "mock-topic"
|
||||
}
|
||||
|
||||
// Properties returns message properties.
|
||||
func (m *MockMessage) Properties() map[string]string {
|
||||
return make(map[string]string)
|
||||
}
|
||||
|
||||
// Payload returns the message payload.
|
||||
func (m *MockMessage) Payload() []byte {
|
||||
return m.payload
|
||||
}
|
||||
|
||||
// ID returns the message ID.
|
||||
func (m *MockMessage) ID() pulsar.MessageID {
|
||||
return m.id
|
||||
}
|
||||
|
||||
// PublishTime returns the publish time.
|
||||
func (m *MockMessage) PublishTime() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// EventTime returns the event time.
|
||||
func (m *MockMessage) EventTime() time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// Key returns the message key.
|
||||
func (m *MockMessage) Key() string {
|
||||
return m.key
|
||||
}
|
||||
|
||||
// OrderingKey returns the ordering key.
|
||||
func (m *MockMessage) OrderingKey() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RedeliveryCount returns the redelivery count.
|
||||
func (m *MockMessage) RedeliveryCount() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsReplicated returns whether the message is replicated.
|
||||
func (m *MockMessage) IsReplicated() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetReplicatedFrom returns the replication source.
|
||||
func (m *MockMessage) GetReplicatedFrom() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSchemaValue returns the schema value.
|
||||
func (m *MockMessage) GetSchemaValue(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEncryptionContext returns the encryption context.
|
||||
func (m *MockMessage) GetEncryptionContext() *pulsar.EncryptionContext {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Index returns the message index.
|
||||
func (m *MockMessage) Index() *uint64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BrokerPublishTime returns the broker publish time.
|
||||
func (m *MockMessage) BrokerPublishTime() *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProducerName returns the producer name.
|
||||
func (m *MockMessage) ProducerName() string {
|
||||
return "mock-producer"
|
||||
}
|
||||
|
||||
// SchemaVersion returns the schema version.
|
||||
func (m *MockMessage) SchemaVersion() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplicatedFrom returns the replication source.
|
||||
func (m *MockMessage) ReplicatedFrom() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewMockConsumerMessage creates a new mock consumer message.
|
||||
func NewMockConsumerMessage(key string, payload []byte) pulsar.ConsumerMessage {
|
||||
return pulsar.ConsumerMessage{
|
||||
Message: NewMockMessage(key, payload),
|
||||
Consumer: nil,
|
||||
}
|
||||
}
|
||||
120
api/adapter/publisher.go
Normal file
120
api/adapter/publisher.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
logger2 "go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
OperationTopic = "persistent://public/default/operation"
|
||||
RecordTopic = "persistent://public/default/record"
|
||||
)
|
||||
|
||||
// PublisherConfig is the configuration to create a publisher.
|
||||
type PublisherConfig struct {
|
||||
// URL is the Pulsar URL.
|
||||
URL string
|
||||
// TLSTrustCertsFilePath is the path to the CA certificate file for verifying the server certificate.
|
||||
// If empty, TLS verification will be disabled.
|
||||
TLSTrustCertsFilePath string
|
||||
// TLSCertificateFilePath is the path to the client certificate file for mTLS authentication.
|
||||
// If empty, mTLS authentication will be disabled.
|
||||
TLSCertificateFilePath string
|
||||
// TLSKeyFilePath is the path to the client private key file for mTLS authentication.
|
||||
// If empty, mTLS authentication will be disabled.
|
||||
TLSKeyFilePath string
|
||||
// TLSAllowInsecureConnection allows insecure TLS connections (not recommended for production).
|
||||
TLSAllowInsecureConnection bool
|
||||
}
|
||||
|
||||
// Publisher provides the pulsar implementation for watermill publish operations.
|
||||
type Publisher struct {
|
||||
conn pulsar.Client
|
||||
logger logger.Logger
|
||||
pubs map[string]pulsar.Producer
|
||||
}
|
||||
|
||||
// NewPublisher creates a new Publisher.
|
||||
func NewPublisher(config PublisherConfig, adapter logger.Logger) (*Publisher, error) {
|
||||
clientOptions := pulsar.ClientOptions{
|
||||
URL: config.URL,
|
||||
Logger: logger2.NewPulsarLoggerAdapter(adapter),
|
||||
}
|
||||
|
||||
// Configure TLS/mTLS
|
||||
if err := configureTLSForClient(&clientOptions, config, adapter); err != nil {
|
||||
return nil, errors.Join(err, errors.New("failed to configure TLS"))
|
||||
}
|
||||
|
||||
conn, err := pulsar.NewClient(clientOptions)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, errors.New("cannot connect to pulsar"))
|
||||
}
|
||||
|
||||
return NewPublisherWithPulsarClient(conn, adapter)
|
||||
}
|
||||
|
||||
// NewPublisherWithPulsarClient creates a new Publisher with the provided pulsar connection.
|
||||
func NewPublisherWithPulsarClient(conn pulsar.Client, logger logger.Logger) (*Publisher, error) {
|
||||
return &Publisher{
|
||||
conn: conn,
|
||||
pubs: make(map[string]pulsar.Producer),
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Publish publishes message to Pulsar.
|
||||
//
|
||||
// Publish will not return until an ack has been received from Pulsar.
|
||||
// When one of messages delivery fails - function is interrupted.
|
||||
func (p *Publisher) Publish(topic string, messages ...*message.Message) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
producer, found := p.pubs[topic]
|
||||
|
||||
if !found {
|
||||
pr, err := p.conn.CreateProducer(pulsar.ProducerOptions{Topic: topic})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
producer = pr
|
||||
p.pubs[topic] = producer
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
// 跳过 nil 消息
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
p.logger.DebugContext(ctx, "Sending message", "key", msg.UUID, "topic", topic)
|
||||
_, err := producer.Send(ctx, &pulsar.ProducerMessage{
|
||||
Key: msg.UUID,
|
||||
Payload: msg.Payload,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the publisher and the underlying connection.
|
||||
func (p *Publisher) Close() error {
|
||||
for _, pub := range p.pubs {
|
||||
pub.Close()
|
||||
}
|
||||
|
||||
p.conn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
212
api/adapter/publisher_test.go
Normal file
212
api/adapter/publisher_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package adapter_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
func TestNewPublisherWithPulsarClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, pub)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||
err = pub.Publish("test-topic", msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify message was sent
|
||||
producer := mockClient.GetProducer("test-topic")
|
||||
require.NotNil(t, producer)
|
||||
messages := producer.GetMessages()
|
||||
require.Len(t, messages, 1)
|
||||
assert.Equal(t, "test-uuid", messages[0].Key)
|
||||
assert.Equal(t, []byte("test payload"), messages[0].Payload)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_MultipleMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
err = pub.Publish("test-topic", msg1, msg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
producer := mockClient.GetProducer("test-topic")
|
||||
require.NotNil(t, producer)
|
||||
messages := producer.GetMessages()
|
||||
require.Len(t, messages, 2)
|
||||
assert.Equal(t, "uuid-1", messages[0].Key)
|
||||
assert.Equal(t, "uuid-2", messages[1].Key)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_MultipleTopics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
|
||||
err = pub.Publish("topic-1", msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Publish("topic-2", msg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
producer1 := mockClient.GetProducer("topic-1")
|
||||
require.NotNil(t, producer1)
|
||||
messages1 := producer1.GetMessages()
|
||||
require.Len(t, messages1, 1)
|
||||
|
||||
producer2 := mockClient.GetProducer("topic-2")
|
||||
require.NotNil(t, producer2)
|
||||
messages2 := producer2.GetMessages()
|
||||
require.Len(t, messages2, 1)
|
||||
}
|
||||
|
||||
func TestPublisher_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Close_AfterPublish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||
err = pub.Publish("test-topic", msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_ReuseProducer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
err = pub.Publish("test-topic", msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
err = pub.Publish("test-topic", msg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
producer := mockClient.GetProducer("test-topic")
|
||||
require.NotNil(t, producer)
|
||||
messages := producer.GetMessages()
|
||||
require.Len(t, messages, 2)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_EmptyTopic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := message.NewMessage("uuid", []byte("payload"))
|
||||
err = pub.Publish("", msg)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_NilMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Publish with nil message - should handle gracefully
|
||||
err = pub.Publish("test-topic", nil)
|
||||
// May succeed or fail depending on implementation
|
||||
_ = err
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_AfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := message.NewMessage("uuid", []byte("payload"))
|
||||
err = pub.Publish("test-topic", msg)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewPublisher_InvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := adapter.PublisherConfig{
|
||||
URL: "invalid-url",
|
||||
}
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
_, err := adapter.NewPublisher(config, log)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot connect")
|
||||
}
|
||||
274
api/adapter/subscriber.go
Normal file
274
api/adapter/subscriber.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill"
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
logger2 "go.yandata.net/iod/iod/trustlog-sdk/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
SubNameKey contextKey = "subName"
|
||||
ReceiverQueueSizeKey contextKey = "receiverQueueSize"
|
||||
IndexKey contextKey = "index"
|
||||
|
||||
ReceiverQueueSizeDefault = 1000
|
||||
SubNameDefault = "subName"
|
||||
TimeOutDefault = time.Second * 10
|
||||
defaultMessageChannelSize = 10
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
var _ message.Subscriber = &Subscriber{}
|
||||
|
||||
// SubscriberConfig is the configuration to create a subscriber.
|
||||
type SubscriberConfig struct {
|
||||
// URL is the URL to the broker
|
||||
URL string
|
||||
// SubscriberName is the name of the subscription.
|
||||
SubscriberName string
|
||||
// SubscriberType is the type of the subscription.
|
||||
SubscriberType pulsar.SubscriptionType
|
||||
// TLSTrustCertsFilePath is the path to the CA certificate file for verifying the server certificate.
|
||||
// If empty, TLS verification will be disabled.
|
||||
TLSTrustCertsFilePath string
|
||||
// TLSCertificateFilePath is the path to the client certificate file for mTLS authentication.
|
||||
// If empty, mTLS authentication will be disabled.
|
||||
TLSCertificateFilePath string
|
||||
// TLSKeyFilePath is the path to the client private key file for mTLS authentication.
|
||||
// If empty, mTLS authentication will be disabled.
|
||||
TLSKeyFilePath string
|
||||
// TLSAllowInsecureConnection allows insecure TLS connections (not recommended for production).
|
||||
TLSAllowInsecureConnection bool
|
||||
}
|
||||
|
||||
// Subscriber provides the pulsar implementation for watermill subscribe operations.
|
||||
type Subscriber struct {
|
||||
conn pulsar.Client
|
||||
logger logger.Logger
|
||||
|
||||
subsLock sync.RWMutex
|
||||
// Change to map with composite key: topic + subscriptionName + subName
|
||||
subs map[string]pulsar.Consumer
|
||||
closed bool
|
||||
closing chan struct{}
|
||||
SubscribersCount int
|
||||
clientID string
|
||||
|
||||
config SubscriberConfig
|
||||
}
|
||||
|
||||
// NewSubscriber creates a new Subscriber.
|
||||
func NewSubscriber(config SubscriberConfig, adapter logger.Logger) (*Subscriber, error) {
|
||||
clientOptions := pulsar.ClientOptions{
|
||||
URL: config.URL,
|
||||
Logger: logger2.NewPulsarLoggerAdapter(adapter),
|
||||
}
|
||||
|
||||
// Configure TLS/mTLS
|
||||
if err := configureTLSForClient(&clientOptions, config, adapter); err != nil {
|
||||
return nil, errors.Join(err, errors.New("failed to configure TLS"))
|
||||
}
|
||||
|
||||
conn, err := pulsar.NewClient(clientOptions)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, errors.New("cannot connect to Pulsar"))
|
||||
}
|
||||
return NewSubscriberWithPulsarClient(conn, config, adapter)
|
||||
}
|
||||
|
||||
// NewSubscriberWithPulsarClient creates a new Subscriber with the provided pulsar client.
|
||||
func NewSubscriberWithPulsarClient(
|
||||
conn pulsar.Client,
|
||||
config SubscriberConfig,
|
||||
logger logger.Logger,
|
||||
) (*Subscriber, error) {
|
||||
return &Subscriber{
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
closing: make(chan struct{}),
|
||||
clientID: watermill.NewULID(),
|
||||
subs: make(map[string]pulsar.Consumer),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Subscribe subscribes messages from Pulsar.
|
||||
func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
|
||||
output := make(chan *message.Message)
|
||||
|
||||
s.subsLock.Lock()
|
||||
|
||||
subName, ok := ctx.Value(SubNameKey).(string)
|
||||
if !ok {
|
||||
subName = SubNameDefault
|
||||
}
|
||||
|
||||
index, ok := ctx.Value(IndexKey).(int)
|
||||
if !ok {
|
||||
index = 0
|
||||
}
|
||||
|
||||
receiverQueueSize, ok := ctx.Value(ReceiverQueueSizeKey).(int)
|
||||
if !ok {
|
||||
receiverQueueSize = ReceiverQueueSizeDefault
|
||||
}
|
||||
|
||||
subscriptionName := fmt.Sprintf("%s-%s", topic, s.clientID)
|
||||
if s.config.SubscriberName != "" {
|
||||
subscriptionName = s.config.SubscriberName
|
||||
}
|
||||
|
||||
sn := fmt.Sprintf("%s_%s", subscriptionName, subName)
|
||||
n := fmt.Sprintf("%s_%d", sn, index)
|
||||
|
||||
sub, found := s.subs[n]
|
||||
|
||||
if !found {
|
||||
subscribeCtx, cancel := context.WithTimeout(ctx, TimeOutDefault)
|
||||
defer cancel()
|
||||
done := make(chan struct{})
|
||||
|
||||
var sb pulsar.Consumer
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
sb, err = s.conn.Subscribe(pulsar.ConsumerOptions{
|
||||
Topic: topic,
|
||||
Name: n,
|
||||
SubscriptionName: sn,
|
||||
Type: s.config.SubscriberType,
|
||||
MessageChannel: make(chan pulsar.ConsumerMessage, defaultMessageChannelSize),
|
||||
ReceiverQueueSize: receiverQueueSize,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-subscribeCtx.Done():
|
||||
s.subsLock.Unlock()
|
||||
return nil, fmt.Errorf("subscription timeout: %w", subscribeCtx.Err())
|
||||
case <-done:
|
||||
if err != nil {
|
||||
s.subsLock.Unlock()
|
||||
return nil, fmt.Errorf("subscription failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.subs[n] = sb
|
||||
sub = sb
|
||||
}
|
||||
|
||||
s.subsLock.Unlock()
|
||||
|
||||
// 创建本地引用以避免竞态条件
|
||||
localSub := sub
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.closing:
|
||||
s.logger.InfoContext(ctx, "subscriber is closing")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
s.logger.InfoContext(ctx, "exiting on context closure")
|
||||
return
|
||||
case m, msgOk := <-localSub.Chan():
|
||||
if !msgOk {
|
||||
// Channel closed, exit the loop
|
||||
s.logger.InfoContext(ctx, "consumer channel closed")
|
||||
return
|
||||
}
|
||||
go s.processMessage(ctx, output, m, localSub)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (s *Subscriber) processMessage(
|
||||
ctx context.Context,
|
||||
output chan *message.Message,
|
||||
m pulsar.Message,
|
||||
sub pulsar.Consumer,
|
||||
) {
|
||||
if s.isClosed() {
|
||||
return
|
||||
}
|
||||
s.logger.DebugContext(ctx, "Received message", "key", m.Key())
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(ctx)
|
||||
defer cancelCtx()
|
||||
|
||||
msg := message.NewMessage(m.Key(), m.Payload())
|
||||
select {
|
||||
case <-s.closing:
|
||||
s.logger.DebugContext(ctx, "Closing, message discarded", "key", m.Key())
|
||||
return
|
||||
case <-ctx.Done():
|
||||
s.logger.DebugContext(ctx, "Context cancelled, message discarded")
|
||||
return
|
||||
// if this is first can risk 'send on closed channel' errors
|
||||
case output <- msg:
|
||||
s.logger.DebugContext(ctx, "Message sent to consumer")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-msg.Acked():
|
||||
err := sub.Ack(m)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, "Message Ack Failed")
|
||||
}
|
||||
s.logger.DebugContext(ctx, "Message Acked")
|
||||
case <-msg.Nacked():
|
||||
sub.Nack(m)
|
||||
s.logger.DebugContext(ctx, "Message Nacked")
|
||||
case <-s.closing:
|
||||
s.logger.DebugContext(ctx, "Closing, message discarded before ack")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
s.logger.DebugContext(ctx, "Context cancelled, message discarded before ack")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the publisher and the underlying connection. It will attempt to wait for in-flight messages to complete.
|
||||
func (s *Subscriber) Close() error {
|
||||
s.subsLock.Lock()
|
||||
defer s.subsLock.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
s.logger.DebugContext(context.Background(), "Closing subscriber")
|
||||
defer s.logger.InfoContext(context.Background(), "Subscriber closed")
|
||||
|
||||
close(s.closing)
|
||||
|
||||
for _, sub := range s.subs {
|
||||
sub.Close()
|
||||
}
|
||||
s.conn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Subscriber) isClosed() bool {
|
||||
s.subsLock.RLock()
|
||||
defer s.subsLock.RUnlock()
|
||||
|
||||
return s.closed
|
||||
}
|
||||
216
api/adapter/subscriber_advanced_test.go
Normal file
216
api/adapter/subscriber_advanced_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package adapter_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
func TestSubscriber_Subscribe_WithAllContextValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.WithValue(context.Background(), adapter.SubNameKey, "custom-sub")
|
||||
ctx = context.WithValue(ctx, adapter.IndexKey, 2)
|
||||
ctx = context.WithValue(ctx, adapter.ReceiverQueueSizeKey, 1500)
|
||||
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_ReuseExistingConsumer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Subscribe first time
|
||||
msgChan1, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan1)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Subscribe again with same topic - should reuse consumer
|
||||
msgChan2, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan2)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_DifferentIndices(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx1 := context.WithValue(context.Background(), adapter.IndexKey, 0)
|
||||
msgChan1, err := sub.Subscribe(ctx1, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan1)
|
||||
|
||||
ctx2 := context.WithValue(context.Background(), adapter.IndexKey, 1)
|
||||
msgChan2, err := sub.Subscribe(ctx2, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan2)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_WithoutSubscriberName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Close_WithMultipleSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = sub.Subscribe(ctx, "topic-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sub.Subscribe(ctx, "topic-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_EmptyMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
defer pub.Close()
|
||||
|
||||
// Publish with no messages - should succeed
|
||||
err = pub.Publish("test-topic")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_MultipleMessagesSameTopic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
defer pub.Close()
|
||||
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
msg3 := message.NewMessage("uuid-3", []byte("payload-3"))
|
||||
|
||||
err = pub.Publish("test-topic", msg1, msg2, msg3)
|
||||
require.NoError(t, err)
|
||||
|
||||
producer := mockClient.GetProducer("test-topic")
|
||||
require.NotNil(t, producer)
|
||||
messages := producer.GetMessages()
|
||||
require.Len(t, messages, 3)
|
||||
}
|
||||
|
||||
func TestPublisher_Close_WithMultipleProducers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
|
||||
err = pub.Publish("topic-1", msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Publish("topic-2", msg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_Close_MultipleTimes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close again should be safe
|
||||
err = pub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
195
api/adapter/subscriber_edge_test.go
Normal file
195
api/adapter/subscriber_edge_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package adapter_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// MockPulsarClientWithSubscribeError is a mock client that can return subscription errors.
|
||||
type MockPulsarClientWithSubscribeError struct {
|
||||
*mocks.MockPulsarClient
|
||||
|
||||
subscribeError error
|
||||
}
|
||||
|
||||
func NewMockPulsarClientWithSubscribeError() *MockPulsarClientWithSubscribeError {
|
||||
return &MockPulsarClientWithSubscribeError{
|
||||
MockPulsarClient: mocks.NewMockPulsarClient(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockPulsarClientWithSubscribeError) SetSubscribeError(err error) {
|
||||
m.subscribeError = err
|
||||
}
|
||||
|
||||
func (m *MockPulsarClientWithSubscribeError) Subscribe(options pulsar.ConsumerOptions) (pulsar.Consumer, error) {
|
||||
if m.subscribeError != nil {
|
||||
return nil, m.subscribeError
|
||||
}
|
||||
return m.MockPulsarClient.Subscribe(options)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_SubscriptionError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := NewMockPulsarClientWithSubscribeError()
|
||||
mockClient.SetSubscribeError(errors.New("subscription failed"))
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = sub.Subscribe(ctx, "test-topic")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "subscription failed")
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
// Use a very short timeout context that's already expired
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
|
||||
cancel() // Cancel immediately
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
_, err = sub.Subscribe(ctx, "test-topic")
|
||||
// Should timeout or fail due to cancelled context
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_WithCustomSubName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.WithValue(context.Background(), adapter.SubNameKey, "custom-sub-name")
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_WithCustomIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.WithValue(context.Background(), adapter.IndexKey, 5)
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_WithCustomReceiverQueueSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
ctx := context.WithValue(context.Background(), adapter.ReceiverQueueSizeKey, 2000)
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_CreateProducerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
defer pub.Close()
|
||||
|
||||
// Close client before creating producer
|
||||
mockClient.Close()
|
||||
|
||||
msg := message.NewMessage("test-uuid", []byte("test payload"))
|
||||
err = pub.Publish("new-topic", msg)
|
||||
// Should fail when creating producer
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublisher_Publish_SendError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
pub, err := adapter.NewPublisherWithPulsarClient(mockClient, log)
|
||||
require.NoError(t, err)
|
||||
defer pub.Close()
|
||||
|
||||
// Create a producer first
|
||||
msg1 := message.NewMessage("test-uuid-1", []byte("test payload 1"))
|
||||
err = pub.Publish("test-topic", msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the producer to cause send error
|
||||
producer := mockClient.GetProducer("test-topic")
|
||||
require.NotNil(t, producer)
|
||||
producer.Close()
|
||||
|
||||
msg2 := message.NewMessage("test-uuid-2", []byte("test payload 2"))
|
||||
err = pub.Publish("test-topic", msg2)
|
||||
// May succeed or fail depending on implementation
|
||||
_ = err
|
||||
}
|
||||
259
api/adapter/subscriber_test.go
Normal file
259
api/adapter/subscriber_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package adapter_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter/mocks"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
func TestNewSubscriberWithPulsarClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_WithContextValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
msgChan, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriber_Close_AfterSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_MultipleTopics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
msgChan1, err := sub.Subscribe(ctx, "topic-1")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan1)
|
||||
|
||||
msgChan2, err := sub.Subscribe(ctx, "topic-2")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan2)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_ReuseConsumer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
msgChan1, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
msgChan2, err := sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, msgChan1)
|
||||
assert.NotNil(t, msgChan2)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err = sub.Subscribe(ctx, "test-topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Wait a bit for goroutine to process cancellation
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Close subscriber
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_EmptyTopic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
msgChan, err := sub.Subscribe(ctx, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, msgChan)
|
||||
}
|
||||
|
||||
func TestSubscriber_Close_MultipleTimes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close again should be safe
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriber_Subscribe_AfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := mocks.NewMockPulsarClient()
|
||||
log := logger.NewNopLogger()
|
||||
config := adapter.SubscriberConfig{
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
|
||||
sub, err := adapter.NewSubscriberWithPulsarClient(mockClient, config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sub.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = sub.Subscribe(ctx, "test-topic")
|
||||
// Behavior depends on implementation - may succeed or fail
|
||||
_ = err
|
||||
}
|
||||
|
||||
func TestNewSubscriber_InvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := adapter.SubscriberConfig{
|
||||
URL: "invalid-url",
|
||||
SubscriberName: "test-sub",
|
||||
SubscriberType: pulsar.Shared,
|
||||
}
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
_, err := adapter.NewSubscriber(config, log)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot connect")
|
||||
}
|
||||
229
api/adapter/tcp_integration_test.go
Normal file
229
api/adapter/tcp_integration_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
)
|
||||
|
||||
// 简单的测试日志适配器
|
||||
type testLogger struct{}
|
||||
|
||||
func (t *testLogger) InfoContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||
func (t *testLogger) DebugContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||
func (t *testLogger) WarnContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||
func (t *testLogger) ErrorContext(ctx context.Context, msg string, args ...interface{}) {}
|
||||
func (t *testLogger) Info(msg string, args ...interface{}) {}
|
||||
func (t *testLogger) Debug(msg string, args ...interface{}) {}
|
||||
func (t *testLogger) Warn(msg string, args ...interface{}) {}
|
||||
func (t *testLogger) Error(msg string, args ...interface{}) {}
|
||||
|
||||
func TestTCPPublisherSubscriber_Integration(t *testing.T) {
|
||||
testLogger := &testLogger{}
|
||||
|
||||
// 创建 Subscriber
|
||||
subscriberConfig := TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:18080",
|
||||
}
|
||||
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscriber: %v", err)
|
||||
}
|
||||
defer subscriber.Close()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 订阅 topic
|
||||
ctx := context.Background()
|
||||
topic := "test-topic"
|
||||
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe: %v", err)
|
||||
}
|
||||
|
||||
// 创建 Publisher
|
||||
publisherConfig := TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:18080",
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create publisher: %v", err)
|
||||
}
|
||||
defer publisher.Close()
|
||||
|
||||
// 测试发送和接收消息
|
||||
testPayload := []byte("Hello, TCP Watermill!")
|
||||
testMsg := message.NewMessage("test-msg-1", testPayload)
|
||||
|
||||
// 启动接收协程
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case receivedMsg := <-msgChan:
|
||||
if string(receivedMsg.Payload) != string(testPayload) {
|
||||
t.Errorf("Payload mismatch: got %s, want %s", receivedMsg.Payload, testPayload)
|
||||
}
|
||||
if receivedMsg.UUID != testMsg.UUID {
|
||||
t.Errorf("UUID mismatch: got %s, want %s", receivedMsg.UUID, testMsg.UUID)
|
||||
}
|
||||
// ACK 消息
|
||||
receivedMsg.Ack()
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout waiting for message")
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送消息
|
||||
err = publisher.Publish(topic, testMsg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish message: %v", err)
|
||||
}
|
||||
|
||||
// 等待接收完成
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestTCPPublisherSubscriber_MultipleMessages(t *testing.T) {
|
||||
testLogger := &testLogger{}
|
||||
|
||||
// 创建 Subscriber
|
||||
subscriberConfig := TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:18081",
|
||||
}
|
||||
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscriber: %v", err)
|
||||
}
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 订阅
|
||||
ctx := context.Background()
|
||||
topic := "test-topic-multi"
|
||||
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe: %v", err)
|
||||
}
|
||||
|
||||
// 创建 Publisher
|
||||
publisherConfig := TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:18081",
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create publisher: %v", err)
|
||||
}
|
||||
defer publisher.Close()
|
||||
|
||||
// 准备多条消息
|
||||
messageCount := 10
|
||||
messages := make([]*message.Message, messageCount)
|
||||
for i := 0; i < messageCount; i++ {
|
||||
payload := []byte("Message " + string(rune('0'+i)))
|
||||
messages[i] = message.NewMessage("msg-"+string(rune('0'+i)), payload)
|
||||
}
|
||||
|
||||
// 启动接收协程
|
||||
receivedCount := 0
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < messageCount; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case receivedMsg := <-msgChan:
|
||||
mu.Lock()
|
||||
receivedCount++
|
||||
mu.Unlock()
|
||||
receivedMsg.Ack()
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("Timeout waiting for message")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 发送消息(并发发送)
|
||||
err = publisher.Publish(topic, messages...)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish messages: %v", err)
|
||||
}
|
||||
|
||||
// 等待接收完成
|
||||
wg.Wait()
|
||||
|
||||
if receivedCount != messageCount {
|
||||
t.Errorf("Received count mismatch: got %d, want %d", receivedCount, messageCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPPublisherSubscriber_Nack(t *testing.T) {
|
||||
testLogger := &testLogger{}
|
||||
|
||||
// 创建 Subscriber
|
||||
subscriberConfig := TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:18082",
|
||||
}
|
||||
subscriber, err := NewTCPSubscriber(subscriberConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscriber: %v", err)
|
||||
}
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 订阅
|
||||
ctx := context.Background()
|
||||
topic := "test-topic-nack"
|
||||
msgChan, err := subscriber.Subscribe(ctx, topic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe: %v", err)
|
||||
}
|
||||
|
||||
// 创建 Publisher
|
||||
publisherConfig := TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:18082",
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
publisher, err := NewTCPPublisher(publisherConfig, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create publisher: %v", err)
|
||||
}
|
||||
defer publisher.Close()
|
||||
|
||||
// 准备消息
|
||||
testMsg := message.NewMessage("nack-test", []byte("This will be nacked"))
|
||||
|
||||
// 启动接收协程,这次 NACK
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case receivedMsg := <-msgChan:
|
||||
// NACK 消息
|
||||
receivedMsg.Nack()
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout waiting for message")
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送消息,由于不等待ACK,应该立即返回成功
|
||||
// 注意:即使消费者NACK,发布者也会返回成功
|
||||
err = publisher.Publish(topic, testMsg)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error (fire-and-forget), got: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
149
api/adapter/tcp_protocol.go
Normal file
149
api/adapter/tcp_protocol.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// 协议常量.
|
||||
const (
|
||||
// MessageTypeData 表示数据消息.
|
||||
MessageTypeData byte = 0x01
|
||||
// MessageTypeAck 表示 ACK 确认.
|
||||
MessageTypeAck byte = 0x02
|
||||
// MessageTypeNack 表示 NACK 否定确认.
|
||||
MessageTypeNack byte = 0x03
|
||||
|
||||
// 协议限制.
|
||||
maxTopicLength = 65535
|
||||
maxUUIDLength = 255
|
||||
maxPayloadSize = 1 << 30
|
||||
topicLengthSize = 2
|
||||
uuidLengthSize = 1
|
||||
payloadLengthSize = 4
|
||||
)
|
||||
|
||||
// 预定义错误.
|
||||
var (
|
||||
ErrNilMessage = errors.New("message is nil")
|
||||
ErrTopicTooLong = errors.New("topic too long")
|
||||
ErrUUIDTooLong = errors.New("uuid too long")
|
||||
ErrPayloadTooLarge = errors.New("payload too large")
|
||||
)
|
||||
|
||||
// TCPMessage 表示 TCP 传输的消息.
|
||||
type TCPMessage struct {
|
||||
Type byte // 消息类型
|
||||
Topic string // 主题
|
||||
UUID string // 消息 UUID
|
||||
Payload []byte // 消息内容
|
||||
}
|
||||
|
||||
// EncodeTCPMessage 将消息编码为字节数组.
|
||||
// 格式: [消息类型 1字节][Topic长度 2字节][Topic][UUID长度 1字节][UUID][Payload长度 4字节][Payload].
|
||||
func EncodeTCPMessage(msg *TCPMessage) ([]byte, error) {
|
||||
if msg == nil {
|
||||
return nil, ErrNilMessage
|
||||
}
|
||||
|
||||
topicLen := len(msg.Topic)
|
||||
if topicLen > maxTopicLength {
|
||||
return nil, ErrTopicTooLong
|
||||
}
|
||||
|
||||
uuidLen := len(msg.UUID)
|
||||
if uuidLen > maxUUIDLength {
|
||||
return nil, ErrUUIDTooLong
|
||||
}
|
||||
|
||||
payloadLen := len(msg.Payload)
|
||||
if payloadLen > maxPayloadSize {
|
||||
return nil, ErrPayloadTooLarge
|
||||
}
|
||||
|
||||
// 计算总长度
|
||||
totalLen := 1 + topicLengthSize + topicLen + uuidLengthSize + uuidLen + payloadLengthSize + payloadLen
|
||||
buf := make([]byte, totalLen)
|
||||
|
||||
offset := 0
|
||||
|
||||
// 写入消息类型
|
||||
buf[offset] = msg.Type
|
||||
offset++
|
||||
|
||||
// 写入 Topic 长度和内容
|
||||
binary.BigEndian.PutUint16(buf[offset:], uint16(topicLen))
|
||||
offset += topicLengthSize
|
||||
copy(buf[offset:], []byte(msg.Topic))
|
||||
offset += topicLen
|
||||
|
||||
// 写入 UUID 长度和内容
|
||||
buf[offset] = byte(uuidLen)
|
||||
offset++
|
||||
copy(buf[offset:], []byte(msg.UUID))
|
||||
offset += uuidLen
|
||||
|
||||
// 写入 Payload 长度和内容
|
||||
binary.BigEndian.PutUint32(buf[offset:], uint32(payloadLen))
|
||||
offset += payloadLengthSize
|
||||
copy(buf[offset:], msg.Payload)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// DecodeTCPMessage 从字节数组解码消息.
|
||||
func DecodeTCPMessage(reader io.Reader) (*TCPMessage, error) {
|
||||
msg := &TCPMessage{}
|
||||
|
||||
// 读取消息类型
|
||||
msgTypeBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(reader, msgTypeBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.Type = msgTypeBuf[0]
|
||||
|
||||
// 读取 Topic 长度
|
||||
topicLenBuf := make([]byte, topicLengthSize)
|
||||
if _, err := io.ReadFull(reader, topicLenBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topicLen := binary.BigEndian.Uint16(topicLenBuf)
|
||||
|
||||
// 读取 Topic
|
||||
topicBuf := make([]byte, topicLen)
|
||||
if _, err := io.ReadFull(reader, topicBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.Topic = string(topicBuf)
|
||||
|
||||
// 读取 UUID 长度
|
||||
uuidLenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(reader, uuidLenBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
uuidLen := uuidLenBuf[0]
|
||||
|
||||
// 读取 UUID
|
||||
uuidBuf := make([]byte, uuidLen)
|
||||
if _, err := io.ReadFull(reader, uuidBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.UUID = string(uuidBuf)
|
||||
|
||||
// 读取 Payload 长度
|
||||
payloadLenBuf := make([]byte, payloadLengthSize)
|
||||
if _, err := io.ReadFull(reader, payloadLenBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen := binary.BigEndian.Uint32(payloadLenBuf)
|
||||
|
||||
// 读取 Payload
|
||||
payloadBuf := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(reader, payloadBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.Payload = payloadBuf
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
166
api/adapter/tcp_protocol_test.go
Normal file
166
api/adapter/tcp_protocol_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncodeTCPMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *TCPMessage
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid data message",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeData,
|
||||
Topic: "test-topic",
|
||||
UUID: "test-uuid-1234",
|
||||
Payload: []byte("test payload"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid ack message",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeAck,
|
||||
Topic: "",
|
||||
UUID: "test-uuid-5678",
|
||||
Payload: nil,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil message",
|
||||
msg: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := EncodeTCPMessage(tt.msg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("EncodeTCPMessage() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && data == nil {
|
||||
t.Error("EncodeTCPMessage() returned nil data")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTCPMessage(t *testing.T) {
|
||||
// 创建一个测试消息
|
||||
original := &TCPMessage{
|
||||
Type: MessageTypeData,
|
||||
Topic: "test-topic",
|
||||
UUID: "test-uuid-1234",
|
||||
Payload: []byte("test payload data"),
|
||||
}
|
||||
|
||||
// 编码
|
||||
encoded, err := EncodeTCPMessage(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode message: %v", err)
|
||||
}
|
||||
|
||||
// 解码
|
||||
reader := bytes.NewReader(encoded)
|
||||
decoded, err := DecodeTCPMessage(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode message: %v", err)
|
||||
}
|
||||
|
||||
// 验证
|
||||
if decoded.Type != original.Type {
|
||||
t.Errorf("Type mismatch: got %v, want %v", decoded.Type, original.Type)
|
||||
}
|
||||
if decoded.Topic != original.Topic {
|
||||
t.Errorf("Topic mismatch: got %v, want %v", decoded.Topic, original.Topic)
|
||||
}
|
||||
if decoded.UUID != original.UUID {
|
||||
t.Errorf("UUID mismatch: got %v, want %v", decoded.UUID, original.UUID)
|
||||
}
|
||||
if !bytes.Equal(decoded.Payload, original.Payload) {
|
||||
t.Errorf("Payload mismatch: got %v, want %v", decoded.Payload, original.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeRoundTrip(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg *TCPMessage
|
||||
}{
|
||||
{
|
||||
name: "data message with payload",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeData,
|
||||
Topic: "persistent://public/default/test",
|
||||
UUID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
Payload: []byte("Hello, World!"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ack message",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeAck,
|
||||
Topic: "",
|
||||
UUID: "test-uuid",
|
||||
Payload: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nack message",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeNack,
|
||||
Topic: "",
|
||||
UUID: "another-uuid",
|
||||
Payload: []byte{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with large payload",
|
||||
msg: &TCPMessage{
|
||||
Type: MessageTypeData,
|
||||
Topic: "test",
|
||||
UUID: "uuid",
|
||||
Payload: make([]byte, 10000),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 编码
|
||||
encoded, err := EncodeTCPMessage(tc.msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
// 解码
|
||||
reader := bytes.NewReader(encoded)
|
||||
decoded, err := DecodeTCPMessage(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有字段
|
||||
if decoded.Type != tc.msg.Type {
|
||||
t.Errorf("Type: got %v, want %v", decoded.Type, tc.msg.Type)
|
||||
}
|
||||
if decoded.Topic != tc.msg.Topic {
|
||||
t.Errorf("Topic: got %v, want %v", decoded.Topic, tc.msg.Topic)
|
||||
}
|
||||
if decoded.UUID != tc.msg.UUID {
|
||||
t.Errorf("UUID: got %v, want %v", decoded.UUID, tc.msg.UUID)
|
||||
}
|
||||
if !bytes.Equal(decoded.Payload, tc.msg.Payload) {
|
||||
t.Errorf("Payload: got %v, want %v", decoded.Payload, tc.msg.Payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
195
api/adapter/tcp_publisher.go
Normal file
195
api/adapter/tcp_publisher.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// 默认配置常量.
|
||||
const (
|
||||
defaultConnectTimeout = 10 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
)
|
||||
|
||||
// 预定义错误.
|
||||
var (
|
||||
ErrServerAddrRequired = errors.New("server address is required")
|
||||
ErrPublisherClosed = errors.New("publisher is closed")
|
||||
)
|
||||
|
||||
// TCPPublisherConfig TCP 发布者配置
|
||||
type TCPPublisherConfig struct {
|
||||
// ServerAddr TCP 服务器地址,格式: "host:port"
|
||||
ServerAddr string
|
||||
// ConnectTimeout 连接超时时间
|
||||
ConnectTimeout time.Duration
|
||||
// MaxRetries 最大重试次数
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
// TCPPublisher 实现基于 TCP 的 watermill Publisher
|
||||
type TCPPublisher struct {
|
||||
config TCPPublisherConfig
|
||||
conn net.Conn
|
||||
logger logger.Logger
|
||||
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
closeChan chan struct{}
|
||||
}
|
||||
|
||||
// NewTCPPublisher 创建一个新的 TCP Publisher.
|
||||
func NewTCPPublisher(config TCPPublisherConfig, logger logger.Logger) (*TCPPublisher, error) {
|
||||
if config.ServerAddr == "" {
|
||||
return nil, ErrServerAddrRequired
|
||||
}
|
||||
|
||||
if config.ConnectTimeout == 0 {
|
||||
config.ConnectTimeout = defaultConnectTimeout
|
||||
}
|
||||
|
||||
if config.MaxRetries == 0 {
|
||||
config.MaxRetries = defaultMaxRetries
|
||||
}
|
||||
|
||||
p := &TCPPublisher{
|
||||
config: config,
|
||||
logger: logger,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 连接到服务器
|
||||
if err := p.connect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 不再接收 ACK/NACK,发送即成功模式
|
||||
// go p.receiveAcks() // 已移除
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// connect 连接到 TCP 服务器
|
||||
func (p *TCPPublisher) connect() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.config.ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
var d net.Dialer
|
||||
conn, err := d.DialContext(ctx, "tcp", p.config.ServerAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to %s: %w", p.config.ServerAddr, err)
|
||||
}
|
||||
|
||||
p.conn = conn
|
||||
p.logger.InfoContext(context.Background(), "Connected to TCP server", "addr", p.config.ServerAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish 发布消息.
|
||||
func (p *TCPPublisher) Publish(topic string, messages ...*message.Message) error {
|
||||
p.closedMu.RLock()
|
||||
if p.closed {
|
||||
p.closedMu.RUnlock()
|
||||
return ErrPublisherClosed
|
||||
}
|
||||
p.closedMu.RUnlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 使用 WaitGroup 和 errChan 来并发发送消息并收集错误
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, 0, len(messages))
|
||||
var errMu sync.Mutex
|
||||
errChan := make(chan error, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(m *message.Message) {
|
||||
defer wg.Done()
|
||||
|
||||
if err := p.publishSingle(ctx, topic, m); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}(msg)
|
||||
}
|
||||
|
||||
// 等待所有消息发送完成
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 检查是否有错误
|
||||
for err := range errChan {
|
||||
errMu.Lock()
|
||||
errs = append(errs, err)
|
||||
errMu.Unlock()
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to publish %d messages: %w", len(errs), errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishSingle 发送单条消息,不等待 ACK
|
||||
func (p *TCPPublisher) publishSingle(ctx context.Context, topic string, msg *message.Message) error {
|
||||
tcpMsg := &TCPMessage{
|
||||
Type: MessageTypeData,
|
||||
Topic: topic,
|
||||
UUID: msg.UUID,
|
||||
Payload: msg.Payload,
|
||||
}
|
||||
|
||||
// 编码消息
|
||||
data, err := EncodeTCPMessage(tcpMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode message: %w", err)
|
||||
}
|
||||
|
||||
p.logger.DebugContext(ctx, "Sending message", "uuid", msg.UUID, "topic", topic)
|
||||
|
||||
// 发送消息
|
||||
if _, err := p.conn.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
||||
p.logger.DebugContext(ctx, "Message sent successfully", "uuid", msg.UUID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiveAcks, shouldStopReceiving, handleDecodeError 方法已移除
|
||||
// 不再接收 ACK/NACK,采用发送即成功模式以提高性能
|
||||
|
||||
// Close 关闭发布者
|
||||
func (p *TCPPublisher) Close() error {
|
||||
p.closedMu.Lock()
|
||||
if p.closed {
|
||||
p.closedMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
p.closedMu.Unlock()
|
||||
|
||||
close(p.closeChan)
|
||||
|
||||
if p.conn != nil {
|
||||
if err := p.conn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.InfoContext(context.Background(), "TCP Publisher closed")
|
||||
return nil
|
||||
}
|
||||
246
api/adapter/tcp_publisher_test.go
Normal file
246
api/adapter/tcp_publisher_test.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package adapter_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/adapter"
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// 验证 TCPPublisher 实现了 message.Publisher 接口
|
||||
func TestTCPPublisher_ImplementsPublisherInterface(t *testing.T) {
|
||||
var _ message.Publisher = (*adapter.TCPPublisher)(nil)
|
||||
}
|
||||
|
||||
func TestNewTCPPublisher_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 首先创建一个订阅者作为服务器
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19090",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19090",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, publisher)
|
||||
|
||||
err = publisher.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNewTCPPublisher_InvalidServerAddr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
_, err := adapter.NewTCPPublisher(config, log)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, adapter.ErrServerAddrRequired)
|
||||
}
|
||||
|
||||
func TestNewTCPPublisher_ConnectionFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 尝试连接到不存在的服务器
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19999",
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
}
|
||||
_, err := adapter.NewTCPPublisher(config, log)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to connect")
|
||||
}
|
||||
|
||||
func TestTCPPublisher_Publish_NoWaitForAck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建订阅者
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19091",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19091",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
defer publisher.Close()
|
||||
|
||||
// 发送消息,应该立即返回成功,不等待ACK
|
||||
msg := message.NewMessage("test-uuid-1", []byte("test payload"))
|
||||
start := time.Now()
|
||||
err = publisher.Publish("test-topic", msg)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// 验证发送成功
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证发送速度很快(不应该等待ACK超时)
|
||||
// 应该在100ms内返回(实际应该只需要几毫秒)
|
||||
assert.Less(t, elapsed, 100*time.Millisecond, "Publish should return immediately without waiting for ACK")
|
||||
}
|
||||
|
||||
func TestTCPPublisher_Publish_MultipleMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建订阅者
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19092",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19092",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
defer publisher.Close()
|
||||
|
||||
// 发送多条消息
|
||||
msg1 := message.NewMessage("uuid-1", []byte("payload-1"))
|
||||
msg2 := message.NewMessage("uuid-2", []byte("payload-2"))
|
||||
msg3 := message.NewMessage("uuid-3", []byte("payload-3"))
|
||||
|
||||
start := time.Now()
|
||||
err = publisher.Publish("test-topic", msg1, msg2, msg3)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.NoError(t, err)
|
||||
// 发送3条消息应该很快完成
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "Publishing multiple messages should be fast")
|
||||
}
|
||||
|
||||
func TestTCPPublisher_Publish_AfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建订阅者
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19093",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19093",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 关闭 Publisher
|
||||
err = publisher.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 尝试在关闭后发送消息
|
||||
msg := message.NewMessage("uuid", []byte("payload"))
|
||||
err = publisher.Publish("test-topic", msg)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, adapter.ErrPublisherClosed)
|
||||
}
|
||||
|
||||
func TestTCPPublisher_Publish_NilMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建订阅者
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19094",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19094",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
defer publisher.Close()
|
||||
|
||||
// 发送 nil 消息应该被忽略
|
||||
err = publisher.Publish("test-topic", nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTCPPublisher_Close_Multiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := logger.NewNopLogger()
|
||||
|
||||
// 创建订阅者
|
||||
subscriberConfig := adapter.TCPSubscriberConfig{
|
||||
ListenAddr: "127.0.0.1:19095",
|
||||
}
|
||||
subscriber, err := adapter.NewTCPSubscriber(subscriberConfig, log)
|
||||
require.NoError(t, err)
|
||||
defer subscriber.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 创建 Publisher
|
||||
config := adapter.TCPPublisherConfig{
|
||||
ServerAddr: "127.0.0.1:19095",
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}
|
||||
publisher, err := adapter.NewTCPPublisher(config, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 多次关闭应该不会报错
|
||||
err = publisher.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = publisher.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
310
api/adapter/tcp_subscriber.go
Normal file
310
api/adapter/tcp_subscriber.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// 订阅者配置常量.
|
||||
const (
|
||||
defaultOutputChannelSize = 100
|
||||
minOutputChannelSize = 10
|
||||
maxOutputChannelSize = 10000
|
||||
)
|
||||
|
||||
// 预定义错误.
|
||||
var (
|
||||
ErrListenAddrRequired = errors.New("listen address is required")
|
||||
ErrSubscriberClosed = errors.New("subscriber is closed")
|
||||
)
|
||||
|
||||
// TCPSubscriberConfig TCP 订阅者配置
|
||||
type TCPSubscriberConfig struct {
|
||||
// ListenAddr 监听地址,格式: "host:port"
|
||||
ListenAddr string
|
||||
|
||||
// OutputChannelSize 输出 channel 的缓冲大小
|
||||
// 较小的值(如 10-50):更快的背压传递,但可能降低吞吐量
|
||||
// 较大的值(如 500-1000):更高的吞吐量,但背压传递较慢
|
||||
// 默认值:100(平衡吞吐量和背压)
|
||||
OutputChannelSize int
|
||||
}
|
||||
|
||||
// TCPSubscriber 实现基于 TCP 的 watermill Subscriber
|
||||
type TCPSubscriber struct {
|
||||
config TCPSubscriberConfig
|
||||
logger logger.Logger
|
||||
listener net.Listener
|
||||
|
||||
subsLock sync.RWMutex
|
||||
subs map[string][]chan *message.Message // topic -> channels
|
||||
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
closeChan chan struct{}
|
||||
|
||||
// 连接管理
|
||||
connMu sync.Mutex
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
// NewTCPSubscriber 创建一个新的 TCP Subscriber.
|
||||
func NewTCPSubscriber(config TCPSubscriberConfig, logger logger.Logger) (*TCPSubscriber, error) {
|
||||
if config.ListenAddr == "" {
|
||||
return nil, ErrListenAddrRequired
|
||||
}
|
||||
|
||||
// 验证和设置 channel 大小
|
||||
channelSize := config.OutputChannelSize
|
||||
if channelSize <= 0 {
|
||||
channelSize = defaultOutputChannelSize
|
||||
}
|
||||
if channelSize < minOutputChannelSize {
|
||||
channelSize = minOutputChannelSize
|
||||
logger.WarnContext(context.Background(), "OutputChannelSize too small, using minimum",
|
||||
"configured", config.OutputChannelSize, "actual", minOutputChannelSize)
|
||||
}
|
||||
if channelSize > maxOutputChannelSize {
|
||||
channelSize = maxOutputChannelSize
|
||||
logger.WarnContext(context.Background(), "OutputChannelSize too large, using maximum",
|
||||
"configured", config.OutputChannelSize, "actual", maxOutputChannelSize)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", config.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on %s: %w", config.ListenAddr, err)
|
||||
}
|
||||
|
||||
// 更新配置中的实际 channel 大小
|
||||
config.OutputChannelSize = channelSize
|
||||
|
||||
s := &TCPSubscriber{
|
||||
config: config,
|
||||
logger: logger,
|
||||
listener: listener,
|
||||
subs: make(map[string][]chan *message.Message),
|
||||
closeChan: make(chan struct{}),
|
||||
conns: make([]net.Conn, 0),
|
||||
}
|
||||
|
||||
// 启动接受连接的协程
|
||||
go s.acceptConnections()
|
||||
|
||||
logger.InfoContext(context.Background(), "TCP Subscriber listening",
|
||||
"addr", config.ListenAddr,
|
||||
"channel_size", channelSize)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// acceptConnections 接受客户端连接
|
||||
func (s *TCPSubscriber) acceptConnections() {
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.closeChan:
|
||||
s.logger.InfoContext(ctx, "Stopping connection acceptor")
|
||||
return
|
||||
default:
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
s.closedMu.RLock()
|
||||
closed := s.closed
|
||||
s.closedMu.RUnlock()
|
||||
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.ErrorContext(ctx, "Failed to accept connection", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.InfoContext(ctx, "Accepted new connection", "remote", conn.RemoteAddr().String())
|
||||
|
||||
// 保存连接
|
||||
s.connMu.Lock()
|
||||
s.conns = append(s.conns, conn)
|
||||
s.connMu.Unlock()
|
||||
|
||||
// 为每个连接启动处理协程
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection 处理单个客户端连接
|
||||
func (s *TCPSubscriber) handleConnection(conn net.Conn) {
|
||||
ctx := context.Background()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
s.logger.InfoContext(ctx, "Connection closed", "remote", conn.RemoteAddr().String())
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.closeChan:
|
||||
return
|
||||
default:
|
||||
// 读取消息
|
||||
tcpMsg, err := DecodeTCPMessage(conn)
|
||||
if err != nil {
|
||||
s.closedMu.RLock()
|
||||
closed := s.closed
|
||||
s.closedMu.RUnlock()
|
||||
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.ErrorContext(ctx, "Failed to decode message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if tcpMsg.Type != MessageTypeData {
|
||||
s.logger.WarnContext(ctx, "Unexpected message type", "type", tcpMsg.Type)
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
s.handleMessage(ctx, conn, tcpMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage 处理消息(发送即成功模式,无需 ACK/NACK)
|
||||
func (s *TCPSubscriber) handleMessage(ctx context.Context, conn net.Conn, tcpMsg *TCPMessage) {
|
||||
s.logger.DebugContext(ctx, "Received message", "uuid", tcpMsg.UUID, "topic", tcpMsg.Topic)
|
||||
|
||||
// 获取该 topic 的订阅者
|
||||
s.subsLock.RLock()
|
||||
channels, found := s.subs[tcpMsg.Topic]
|
||||
s.subsLock.RUnlock()
|
||||
|
||||
if !found || len(channels) == 0 {
|
||||
s.logger.WarnContext(ctx, "No subscribers for topic", "topic", tcpMsg.Topic)
|
||||
// 不再发送 NACK,直接丢弃消息
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 watermill 消息
|
||||
msg := message.NewMessage(tcpMsg.UUID, tcpMsg.Payload)
|
||||
|
||||
// 使用随机策略选择订阅者(无锁,性能更好)
|
||||
randomIndex := rand.Intn(len(channels))
|
||||
outputChan := channels[randomIndex]
|
||||
|
||||
// 记录 channel 使用情况,便于监控背压
|
||||
channelLen := len(outputChan)
|
||||
channelCap := cap(outputChan)
|
||||
usage := float64(channelLen) / float64(channelCap) * 100
|
||||
|
||||
s.logger.DebugContext(ctx, "Dispatching message via random selection",
|
||||
"uuid", tcpMsg.UUID,
|
||||
"subscriber_index", randomIndex,
|
||||
"total_subscribers", len(channels),
|
||||
"channel_usage", fmt.Sprintf("%.1f%% (%d/%d)", usage, channelLen, channelCap))
|
||||
|
||||
// 阻塞式发送:当 channel 满时会阻塞,从而触发 TCP 背压
|
||||
// 这会导致:
|
||||
// 1. 当前 goroutine 阻塞
|
||||
// 2. TCP 读取停止
|
||||
// 3. TCP 接收窗口填满
|
||||
// 4. 发送端收到零窗口通知
|
||||
// 5. 发送端停止发送
|
||||
select {
|
||||
case outputChan <- msg:
|
||||
s.logger.DebugContext(ctx, "Message sent to subscriber", "uuid", tcpMsg.UUID, "index", randomIndex)
|
||||
// 发送即成功:立即 Ack 消息,不等待处理结果
|
||||
msg.Ack()
|
||||
case <-s.closeChan:
|
||||
s.logger.DebugContext(ctx, "Subscriber closed, message discarded", "uuid", tcpMsg.UUID)
|
||||
return
|
||||
}
|
||||
|
||||
// 不再等待消息被 ACK 或 NACK,也不发送 ACK/NACK 回执
|
||||
}
|
||||
|
||||
// sendAck 方法已移除
|
||||
// 采用发送即成功模式,不再发送 ACK/NACK 回执以提高性能
|
||||
|
||||
// Subscribe 订阅指定 topic 的消息.
|
||||
func (s *TCPSubscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
|
||||
s.closedMu.RLock()
|
||||
if s.closed {
|
||||
s.closedMu.RUnlock()
|
||||
return nil, ErrSubscriberClosed
|
||||
}
|
||||
s.closedMu.RUnlock()
|
||||
|
||||
// 使用配置的 channel 大小
|
||||
channelSize := s.config.OutputChannelSize
|
||||
if channelSize <= 0 {
|
||||
channelSize = defaultOutputChannelSize
|
||||
}
|
||||
output := make(chan *message.Message, channelSize)
|
||||
|
||||
s.subsLock.Lock()
|
||||
if s.subs[topic] == nil {
|
||||
s.subs[topic] = make([]chan *message.Message, 0)
|
||||
}
|
||||
s.subs[topic] = append(s.subs[topic], output)
|
||||
subscriberCount := len(s.subs[topic])
|
||||
s.subsLock.Unlock()
|
||||
|
||||
s.logger.InfoContext(ctx, "Subscribed to topic",
|
||||
"topic", topic,
|
||||
"subscriber_count", subscriberCount,
|
||||
"channel_size", channelSize)
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// Close 关闭订阅者
|
||||
func (s *TCPSubscriber) Close() error {
|
||||
s.closedMu.Lock()
|
||||
if s.closed {
|
||||
s.closedMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
s.closedMu.Unlock()
|
||||
|
||||
close(s.closeChan)
|
||||
|
||||
// 关闭监听器
|
||||
if s.listener != nil {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.ErrorContext(context.Background(), "Failed to close listener", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭所有连接
|
||||
s.connMu.Lock()
|
||||
for _, conn := range s.conns {
|
||||
conn.Close()
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
// 关闭所有订阅通道
|
||||
s.subsLock.Lock()
|
||||
for topic, channels := range s.subs {
|
||||
for _, ch := range channels {
|
||||
close(ch)
|
||||
}
|
||||
delete(s.subs, topic)
|
||||
}
|
||||
s.subsLock.Unlock()
|
||||
|
||||
s.logger.InfoContext(context.Background(), "TCP Subscriber closed")
|
||||
return nil
|
||||
}
|
||||
123
api/adapter/tls_config.go
Normal file
123
api/adapter/tls_config.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/apache/pulsar-client-go/pulsar"
|
||||
"github.com/apache/pulsar-client-go/pulsar/auth"
|
||||
|
||||
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
||||
)
|
||||
|
||||
// tlsConfigProvider defines the interface for TLS configuration.
|
||||
type tlsConfigProvider interface {
|
||||
GetTLSTrustCertsFilePath() string
|
||||
GetTLSCertificateFilePath() string
|
||||
GetTLSKeyFilePath() string
|
||||
GetTLSAllowInsecureConnection() bool
|
||||
}
|
||||
|
||||
// configureTLSForClient configures TLS/mTLS settings for the Pulsar client.
|
||||
func configureTLSForClient(opts *pulsar.ClientOptions, config tlsConfigProvider, logger logger.Logger) error {
|
||||
// If no TLS configuration is provided, skip TLS setup
|
||||
if config.GetTLSTrustCertsFilePath() == "" &&
|
||||
config.GetTLSCertificateFilePath() == "" &&
|
||||
config.GetTLSKeyFilePath() == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Configure TLS trust certificates
|
||||
if config.GetTLSTrustCertsFilePath() != "" {
|
||||
if _, err := os.ReadFile(config.GetTLSTrustCertsFilePath()); err != nil {
|
||||
return errors.Join(err, errors.New("failed to read TLS trust certificates file"))
|
||||
}
|
||||
opts.TLSTrustCertsFilePath = config.GetTLSTrustCertsFilePath()
|
||||
logger.Debug(
|
||||
"TLS trust certificates configured",
|
||||
"path", config.GetTLSTrustCertsFilePath(),
|
||||
)
|
||||
}
|
||||
|
||||
// Configure TLS allow insecure connection
|
||||
opts.TLSAllowInsecureConnection = config.GetTLSAllowInsecureConnection()
|
||||
|
||||
// Configure mTLS authentication if both certificate and key are provided
|
||||
if config.GetTLSCertificateFilePath() != "" && config.GetTLSKeyFilePath() != "" {
|
||||
// Load client certificate and key
|
||||
cert, err := tls.LoadX509KeyPair(
|
||||
config.GetTLSCertificateFilePath(),
|
||||
config.GetTLSKeyFilePath(),
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Join(err, errors.New("failed to load client certificate and key"))
|
||||
}
|
||||
|
||||
// Create TLS authentication provider
|
||||
// Pulsar Go client uses auth.NewAuthenticationTLS with certificate and key file paths
|
||||
tlsAuth := auth.NewAuthenticationTLS(
|
||||
config.GetTLSCertificateFilePath(),
|
||||
config.GetTLSKeyFilePath(),
|
||||
)
|
||||
|
||||
opts.Authentication = tlsAuth
|
||||
logger.Debug(
|
||||
"mTLS authentication configured",
|
||||
"cert", config.GetTLSCertificateFilePath(),
|
||||
"key", config.GetTLSKeyFilePath(),
|
||||
)
|
||||
|
||||
// Verify the certificate is valid
|
||||
if _, parseErr := x509.ParseCertificate(cert.Certificate[0]); parseErr != nil {
|
||||
return errors.Join(parseErr, errors.New("invalid client certificate"))
|
||||
}
|
||||
} else if config.GetTLSCertificateFilePath() != "" || config.GetTLSKeyFilePath() != "" {
|
||||
return errors.New(
|
||||
"both TLS certificate and key file paths must be provided for mTLS authentication",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTLSTrustCertsFilePath returns the TLS trust certificates file path for PublisherConfig.
|
||||
func (c PublisherConfig) GetTLSTrustCertsFilePath() string {
|
||||
return c.TLSTrustCertsFilePath
|
||||
}
|
||||
|
||||
// GetTLSCertificateFilePath returns the TLS certificate file path for PublisherConfig.
|
||||
func (c PublisherConfig) GetTLSCertificateFilePath() string {
|
||||
return c.TLSCertificateFilePath
|
||||
}
|
||||
|
||||
// GetTLSKeyFilePath returns the TLS key file path for PublisherConfig.
|
||||
func (c PublisherConfig) GetTLSKeyFilePath() string {
|
||||
return c.TLSKeyFilePath
|
||||
}
|
||||
|
||||
// GetTLSAllowInsecureConnection returns whether to allow insecure TLS connections for PublisherConfig.
|
||||
func (c PublisherConfig) GetTLSAllowInsecureConnection() bool {
|
||||
return c.TLSAllowInsecureConnection
|
||||
}
|
||||
|
||||
// GetTLSTrustCertsFilePath returns the TLS trust certificates file path for SubscriberConfig.
|
||||
func (c SubscriberConfig) GetTLSTrustCertsFilePath() string {
|
||||
return c.TLSTrustCertsFilePath
|
||||
}
|
||||
|
||||
// GetTLSCertificateFilePath returns the TLS certificate file path for SubscriberConfig.
|
||||
func (c SubscriberConfig) GetTLSCertificateFilePath() string {
|
||||
return c.TLSCertificateFilePath
|
||||
}
|
||||
|
||||
// GetTLSKeyFilePath returns the TLS key file path for SubscriberConfig.
|
||||
func (c SubscriberConfig) GetTLSKeyFilePath() string {
|
||||
return c.TLSKeyFilePath
|
||||
}
|
||||
|
||||
// GetTLSAllowInsecureConnection returns whether to allow insecure TLS connections for SubscriberConfig.
|
||||
func (c SubscriberConfig) GetTLSAllowInsecureConnection() bool {
|
||||
return c.TLSAllowInsecureConnection
|
||||
}
|
||||
Reference in New Issue
Block a user