package grpcclient_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "go.yandata.net/iod/iod/go-trustlog/internal/grpcclient" ) // mockClient 用于测试的模拟客户端. type mockClient struct { ID string } func TestNewLoadBalancer(t *testing.T) { tests := []struct { name string addrs []string dialOpts []grpc.DialOption wantErr bool errMsg string }{ { name: "成功创建负载均衡器", addrs: []string{ "localhost:9090", "localhost:9091", }, dialOpts: []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), }, wantErr: false, }, { name: "没有地址应该失败", addrs: []string{}, dialOpts: nil, wantErr: true, errMsg: "at least one server address is required", }, { name: "nil地址列表应该失败", addrs: nil, dialOpts: nil, wantErr: true, errMsg: "at least one server address is required", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { lb, err := grpcclient.NewLoadBalancer( tt.addrs, tt.dialOpts, func(_ grpc.ClientConnInterface) *mockClient { return &mockClient{ID: "test"} }, ) if tt.wantErr { require.Error(t, err) if tt.errMsg != "" { assert.Contains(t, err.Error(), tt.errMsg) } assert.Nil(t, lb) } else { // 注意:这里会实际尝试连接,在测试环境下可能失败 // 实际使用时应该使用 mock 或 bufconn if err != nil { t.Skipf("Skipping test - cannot connect to servers: %v", err) return } require.NoError(t, err) require.NotNil(t, lb) assert.Equal(t, len(tt.addrs), lb.ServerCount()) // 清理 _ = lb.Close() } }) } } func TestLoadBalancer_Next(t *testing.T) { // 创建一个模拟的负载均衡器,不需要真实连接 t.Run("轮询算法测试", func(t *testing.T) { // 这个测试需要使用 bufconn 或其他 mock 方式 // 暂时跳过需要真实连接的测试 if testing.Short() { t.Skip("Skipping test that requires network connection") } addrs := []string{"localhost:9090", "localhost:9091", "localhost:9092"} lb, err := grpcclient.NewLoadBalancer( addrs, []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, func(_ grpc.ClientConnInterface) *mockClient { return &mockClient{ID: "test"} }, ) if err != nil { t.Skipf("Cannot create load balancer: %v", err) return } defer lb.Close() // 测试轮询:调用 Next() 多次应该轮询返回不同的客户端 clients := make([]*mockClient, 6) for i := range 6 { clients[i] = lb.Next() assert.NotNil(t, clients[i]) } }) } func TestLoadBalancer_Close(t *testing.T) { if testing.Short() { t.Skip("Skipping test that requires network connection") } addrs := []string{"localhost:9090"} lb, err := grpcclient.NewLoadBalancer( addrs, []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, func(_ grpc.ClientConnInterface) *mockClient { return &mockClient{ID: "test"} }, ) if err != nil { t.Skipf("Cannot create load balancer: %v", err) return } // 第一次关闭 err = lb.Close() require.NoError(t, err) // 再次关闭应该也不会报错 err = lb.Close() assert.NoError(t, err) } func TestLoadBalancer_ServerCount(t *testing.T) { if testing.Short() { t.Skip("Skipping test that requires network connection") } tests := []struct { name string addrs []string wantCount int }{ { name: "单服务器", addrs: []string{"localhost:9090"}, wantCount: 1, }, { name: "多服务器", addrs: []string{"localhost:9090", "localhost:9091", "localhost:9092"}, wantCount: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { lb, err := grpcclient.NewLoadBalancer( tt.addrs, []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, func(_ grpc.ClientConnInterface) *mockClient { return &mockClient{ID: "test"} }, ) if err != nil { t.Skipf("Cannot create load balancer: %v", err) return } defer lb.Close() assert.Equal(t, tt.wantCount, lb.ServerCount()) }) } }