// Package grpcclient provides a generic round-robin load balancer over a
// fixed set of gRPC client connections.
package grpcclient

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// ClientFactory builds a typed client of type T on top of a gRPC connection.
type ClientFactory[T any] func(grpc.ClientConnInterface) T

// ServerClient bundles the connection to a single server together with the
// client that was built on top of it.
type ServerClient[T any] struct {
	addr   string           // target address the connection was created for
	conn   *grpc.ClientConn // underlying connection, closed by LoadBalancer.Close
	client T                // client produced by the factory for conn
}

// LoadBalancer hands out clients of type T in round-robin order across a
// fixed list of servers.
type LoadBalancer[T any] struct {
	servers []*ServerClient[T]
	counter atomic.Uint64 // monotonically increasing pick counter used by Next
	mu      sync.RWMutex  // guards servers and closed
	closed  bool          // set once Close has run
}

// NewLoadBalancer creates one gRPC client connection per address in addrs and
// wraps each of them with factory.
//
// The dial options start with insecure transport credentials (production
// deployments should use TLS); dialOpts are appended after that default.
//
// It returns an error when addrs is empty, when factory is nil, or when
// grpc.NewClient fails for any address. In the last case every connection
// created so far is closed before returning.
func NewLoadBalancer[T any](
	addrs []string,
	dialOpts []grpc.DialOption,
	factory ClientFactory[T],
) (*LoadBalancer[T], error) {
	if len(addrs) == 0 {
		return nil, errors.New("at least one server address is required")
	}
	// Reject a nil factory up front: calling it inside the loop would panic
	// after connections have already been opened, leaving them unclosed.
	if factory == nil {
		return nil, errors.New("client factory is required")
	}

	lb := &LoadBalancer[T]{
		servers: make([]*ServerClient[T], 0, len(addrs)),
	}

	// Insecure credentials are the default; caller options follow it.
	opts := make([]grpc.DialOption, 0, len(dialOpts)+1)
	opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	opts = append(opts, dialOpts...)

	for _, addr := range addrs {
		conn, err := grpc.NewClient(addr, opts...)
		if err != nil {
			// Best-effort cleanup of the connections created so far; the
			// NewClient failure is the error worth reporting to the caller.
			_ = lb.Close()
			return nil, fmt.Errorf("failed to connect to server %s: %w", addr, err)
		}

		lb.servers = append(lb.servers, &ServerClient[T]{
			addr:   addr,
			conn:   conn,
			client: factory(conn),
		})
	}

	return lb, nil
}

// Next returns the client of the next server in round-robin order.
//
// It returns the zero value of T when the balancer has been closed or holds
// no servers. The counter is incremented before it is used, so the very first
// call selects index 1 % len(servers) rather than index 0.
func (lb *LoadBalancer[T]) Next() T {
	lb.mu.RLock()
	defer lb.mu.RUnlock()

	if len(lb.servers) == 0 || lb.closed {
		var zero T
		return zero
	}

	// Atomically advance the counter and map it onto the server list.
	idx := lb.counter.Add(1) % uint64(len(lb.servers))
	return lb.servers[idx].client
}

// Close closes every server connection and marks the balancer as closed.
//
// Every connection is attempted even if an earlier one fails; all failures
// are reported together via errors.Join, each annotated with the server
// address. Calling Close again after it has run returns nil.
func (lb *LoadBalancer[T]) Close() error {
	lb.mu.Lock()
	defer lb.mu.Unlock()

	if lb.closed {
		return nil
	}

	var errs []error
	for _, server := range lb.servers {
		if err := server.conn.Close(); err != nil {
			errs = append(errs, fmt.Errorf("closing connection to %s: %w", server.addr, err))
		}
	}

	lb.closed = true
	// errors.Join returns nil when no error was collected.
	return errors.Join(errs...)
}

// ServerCount returns the number of servers held by the balancer. The count
// is not affected by Close.
func (lb *LoadBalancer[T]) ServerCount() int {
	lb.mu.RLock()
	defer lb.mu.RUnlock()
	return len(lb.servers)
}