Files
go-trustlog/internal/grpcclient/loadbalancer.go
ryan d313449c5c refactor: 重构trustlog-sdk目录结构到trustlog/go-trustlog
- 将所有trustlog-sdk文件移动到trustlog/go-trustlog/目录
- 更新README中所有import路径从trustlog-sdk改为go-trustlog
- 更新cookiecutter配置文件中的项目名称
- 更新根目录.lefthook.yml以引用新位置的配置
- 添加go.sum文件到版本控制
- 删除过时的示例文件

这次重构与trustlog-server保持一致的目录结构,
为未来支持多语言SDK(Python、Java等)预留空间。
2025-12-22 13:37:57 +08:00

114 lines
2.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package grpcclient
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// ClientFactory 客户端工厂函数类型.
type ClientFactory[T any] func(grpc.ClientConnInterface) T
// ServerClient 封装单个服务器的连接.
type ServerClient[T any] struct {
addr string
conn *grpc.ClientConn
client T
}
// LoadBalancer 轮询负载均衡器(泛型版本).
type LoadBalancer[T any] struct {
servers []*ServerClient[T]
counter atomic.Uint64
mu sync.RWMutex
closed bool
}
// NewLoadBalancer 创建新的负载均衡器.
func NewLoadBalancer[T any](
addrs []string,
dialOpts []grpc.DialOption,
factory ClientFactory[T],
) (*LoadBalancer[T], error) {
if len(addrs) == 0 {
return nil, errors.New("at least one server address is required")
}
lb := &LoadBalancer[T]{
servers: make([]*ServerClient[T], 0, len(addrs)),
}
// 默认使用不安全的连接生产环境应使用TLS
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
}
opts = append(opts, dialOpts...)
// 连接所有服务器
for _, addr := range addrs {
conn, err := grpc.NewClient(addr, opts...)
if err != nil {
// 关闭已创建的连接
_ = lb.Close()
return nil, fmt.Errorf("failed to connect to server %s: %w", addr, err)
}
client := factory(conn)
lb.servers = append(lb.servers, &ServerClient[T]{
addr: addr,
conn: conn,
client: client,
})
}
return lb, nil
}
// Next 使用轮询算法获取下一个客户端.
func (lb *LoadBalancer[T]) Next() T {
lb.mu.RLock()
defer lb.mu.RUnlock()
if len(lb.servers) == 0 || lb.closed {
var zero T
return zero
}
// 原子递增计数器并取模
idx := lb.counter.Add(1) % uint64(len(lb.servers))
return lb.servers[idx].client
}
// Close 关闭所有连接.
func (lb *LoadBalancer[T]) Close() error {
lb.mu.Lock()
defer lb.mu.Unlock()
// 如果已经关闭,直接返回
if lb.closed {
return nil
}
var lastErr error
for _, server := range lb.servers {
if err := server.conn.Close(); err != nil {
lastErr = err
}
}
// 标记为已关闭
lb.closed = true
return lastErr
}
// ServerCount 返回服务器数量.
func (lb *LoadBalancer[T]) ServerCount() int {
lb.mu.RLock()
defer lb.mu.RUnlock()
return len(lb.servers)
}