package helpers import ( "errors" "fmt" "io" ) // TLVReader 提供 TLV(Type-Length-Value)格式的顺序读取能力。 // 支持无需反序列化全部报文即可读取特定字段。 type TLVReader struct { r io.Reader br io.ByteReader } // NewTLVReader 创建新的 TLVReader。 func NewTLVReader(r io.Reader) *TLVReader { return &TLVReader{ r: r, br: newByteReader(r), } } // ReadField 读取下一个 TLV 字段。 // 返回字段的长度和值。 func (tr *TLVReader) ReadField() ([]byte, error) { length, err := readVarint(tr.br) if err != nil { return nil, fmt.Errorf("failed to read field length: %w", err) } if length == 0 { return nil, nil } value := make([]byte, length) if _, errRead := io.ReadFull(tr.r, value); errRead != nil { return nil, fmt.Errorf("failed to read field value: %w", errRead) } return value, nil } // ReadStringField 读取下一个 TLV 字段并转换为字符串。 func (tr *TLVReader) ReadStringField() (string, error) { data, err := tr.ReadField() if err != nil { return "", err } return string(data), nil } // TLVWriter 提供 TLV 格式的顺序写入能力。 type TLVWriter struct { w io.Writer } // NewTLVWriter 创建新的 TLVWriter。 func NewTLVWriter(w io.Writer) *TLVWriter { return &TLVWriter{w: w} } // WriteField 写入一个 TLV 字段。 func (tw *TLVWriter) WriteField(value []byte) error { if err := writeVarint(tw.w, uint64(len(value))); err != nil { return fmt.Errorf("failed to write field length: %w", err) } if len(value) > 0 { if _, err := tw.w.Write(value); err != nil { return fmt.Errorf("failed to write field value: %w", err) } } return nil } // WriteStringField 写入一个字符串 TLV 字段。 func (tw *TLVWriter) WriteStringField(value string) error { return tw.WriteField([]byte(value)) } // Varint 编码/解码函数 const ( // varintContinueBit 表示 varint 还有后续字节的标志位。 varintContinueBit = 0x80 // varintDataMask 用于提取 varint 数据位的掩码。 varintDataMask = 0x7f // varintMaxShift 表示 varint 最大的位移量,防止溢出。 varintMaxShift = 64 ) // writeVarint 写入变长整数(类似 Protobuf 的 varint 编码)。 // 将 uint64 编码为变长格式,节省存储空间。 // func writeVarint(w io.Writer, x uint64) error { var buf [10]byte n := 0 for x >= varintContinueBit { buf[n] = byte(x) | varintContinueBit x >>= 7 n++ } buf[n] = byte(x) _, err := w.Write(buf[:n+1]) return err } // readVarint 读取变长整数。 // 从字节流中解码 varint 格式的整数。 func readVarint(r io.ByteReader) (uint64, error) { var x uint64 var shift uint for { b, err := r.ReadByte() if err != nil { return 0, err } x |= uint64(b&varintDataMask) << shift if b&varintContinueBit == 0 { return x, nil } shift += 7 if shift >= varintMaxShift { return 0, errors.New("varint overflow") } } } // byteReader 为 io.Reader 实现 io.ByteReader 接口。 // 提供逐字节读取能力,用于 varint 解码。 type byteReader struct { r io.Reader b [1]byte } func newByteReader(r io.Reader) io.ByteReader { return &byteReader{r: r} } func (br *byteReader) ReadByte() (byte, error) { _, err := br.r.Read(br.b[:]) return br.b[0], err }