package adapter import ( "context" "errors" "fmt" "net" "sync" "time" "github.com/ThreeDotsLabs/watermill/message" "go.yandata.net/iod/iod/trustlog-sdk/api/logger" ) // 默认配置常量. const ( defaultConnectTimeout = 10 * time.Second defaultMaxRetries = 3 ) // 预定义错误. var ( ErrServerAddrRequired = errors.New("server address is required") ErrPublisherClosed = errors.New("publisher is closed") ) // TCPPublisherConfig TCP 发布者配置 type TCPPublisherConfig struct { // ServerAddr TCP 服务器地址,格式: "host:port" ServerAddr string // ConnectTimeout 连接超时时间 ConnectTimeout time.Duration // MaxRetries 最大重试次数 MaxRetries int } // TCPPublisher 实现基于 TCP 的 watermill Publisher type TCPPublisher struct { config TCPPublisherConfig conn net.Conn logger logger.Logger closed bool closedMu sync.RWMutex closeChan chan struct{} } // NewTCPPublisher 创建一个新的 TCP Publisher. func NewTCPPublisher(config TCPPublisherConfig, logger logger.Logger) (*TCPPublisher, error) { if config.ServerAddr == "" { return nil, ErrServerAddrRequired } if config.ConnectTimeout == 0 { config.ConnectTimeout = defaultConnectTimeout } if config.MaxRetries == 0 { config.MaxRetries = defaultMaxRetries } p := &TCPPublisher{ config: config, logger: logger, closeChan: make(chan struct{}), } // 连接到服务器 if err := p.connect(); err != nil { return nil, err } // 不再接收 ACK/NACK,发送即成功模式 // go p.receiveAcks() // 已移除 return p, nil } // connect 连接到 TCP 服务器 func (p *TCPPublisher) connect() error { ctx, cancel := context.WithTimeout(context.Background(), p.config.ConnectTimeout) defer cancel() var d net.Dialer conn, err := d.DialContext(ctx, "tcp", p.config.ServerAddr) if err != nil { return fmt.Errorf("failed to connect to %s: %w", p.config.ServerAddr, err) } p.conn = conn p.logger.InfoContext(context.Background(), "Connected to TCP server", "addr", p.config.ServerAddr) return nil } // Publish 发布消息. func (p *TCPPublisher) Publish(topic string, messages ...*message.Message) error { p.closedMu.RLock() if p.closed { p.closedMu.RUnlock() return ErrPublisherClosed } p.closedMu.RUnlock() ctx := context.Background() // 使用 WaitGroup 和 errChan 来并发发送消息并收集错误 var wg sync.WaitGroup errs := make([]error, 0, len(messages)) var errMu sync.Mutex errChan := make(chan error, len(messages)) for _, msg := range messages { if msg == nil { continue } wg.Add(1) go func(m *message.Message) { defer wg.Done() if err := p.publishSingle(ctx, topic, m); err != nil { errChan <- err } }(msg) } // 等待所有消息发送完成 wg.Wait() close(errChan) // 检查是否有错误 for err := range errChan { errMu.Lock() errs = append(errs, err) errMu.Unlock() } if len(errs) > 0 { return fmt.Errorf("failed to publish %d messages: %w", len(errs), errors.Join(errs...)) } return nil } // publishSingle 发送单条消息,不等待 ACK func (p *TCPPublisher) publishSingle(ctx context.Context, topic string, msg *message.Message) error { tcpMsg := &TCPMessage{ Type: MessageTypeData, Topic: topic, UUID: msg.UUID, Payload: msg.Payload, } // 编码消息 data, err := EncodeTCPMessage(tcpMsg) if err != nil { return fmt.Errorf("failed to encode message: %w", err) } p.logger.DebugContext(ctx, "Sending message", "uuid", msg.UUID, "topic", topic) // 发送消息 if _, err := p.conn.Write(data); err != nil { return fmt.Errorf("failed to write message: %w", err) } p.logger.DebugContext(ctx, "Message sent successfully", "uuid", msg.UUID) return nil } // receiveAcks, shouldStopReceiving, handleDecodeError 方法已移除 // 不再接收 ACK/NACK,采用发送即成功模式以提高性能 // Close 关闭发布者 func (p *TCPPublisher) Close() error { p.closedMu.Lock() if p.closed { p.closedMu.Unlock() return nil } p.closed = true p.closedMu.Unlock() close(p.closeChan) if p.conn != nil { if err := p.conn.Close(); err != nil { return fmt.Errorf("failed to close connection: %w", err) } } p.logger.InfoContext(context.Background(), "TCP Publisher closed") return nil }