From ab775089504d5254a5236c2a15d2a749476ad91f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 20 Mar 2026 00:33:50 +0800 Subject: [PATCH] [client] Add env var for management gRPC max receive message size (#5622) --- client/grpc/dialer.go | 11 ++-- shared/management/client/grpc.go | 36 +++++++++- shared/management/client/grpc_test.go | 95 +++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 shared/management/client/grpc_test.go diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54966b50e..9a6bc0670 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff { // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) // for js, the outer websocket layer takes care of tls if tlsEnabled && runtime.GOOS != "js" { @@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - conn, err := grpc.DialContext( - connCtx, - addr, + opts := []grpc.DialOption{ transportOption, WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), @@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone Time: 30 * time.Second, Timeout: 10 * time.Second, }), - ) + } + opts = append(opts, extraOpts...) + + conn, err := grpc.DialContext(connCtx, addr, opts...) if err != nil { return nil, fmt.Errorf("dial context: %w", err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 333f0bf00..e95db0089 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + "os" + "strconv" "sync" "time" @@ -29,6 +31,10 @@ import ( const ConnectTimeout = 10 * time.Second const ( + // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) + // for the management client connection. Value is in bytes. + EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE" + errMsgMgmtPublicKey = "failed getting Management Service public key: %s" errMsgNoMgmtConnection = "no connection to management" ) @@ -66,13 +72,41 @@ type ExposeResponse struct { PortAutoAssigned bool } +// MaxRecvMsgSize returns the configured max gRPC receive message size from +// the environment, or 0 if unset (which uses the gRPC default of 4 MB). +func MaxRecvMsgSize() int { + val := os.Getenv(EnvMaxRecvMsgSize) + if val == "" { + return 0 + } + + size, err := strconv.Atoi(val) + if err != nil { + log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err) + return 0 + } + + if size <= 0 { + log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size) + return 0 + } + + return size +} + // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { var conn *grpc.ClientConn + var extraOpts []grpc.DialOption + if maxSize := MaxRecvMsgSize(); maxSize > 0 { + extraOpts = append(extraOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxSize))) + log.Infof("management gRPC max receive message size set to %d bytes", maxSize) + } + operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent, extraOpts...) if err != nil { return fmt.Errorf("create connection: %w", err) } diff --git a/shared/management/client/grpc_test.go b/shared/management/client/grpc_test.go new file mode 100644 index 000000000..462cc43af --- /dev/null +++ b/shared/management/client/grpc_test.go @@ -0,0 +1,95 @@ +package client + +import ( + "context" + "net" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestMaxRecvMsgSize(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {name: "unset returns 0", envValue: "", expected: 0}, + {name: "valid value", envValue: "10485760", expected: 10485760}, + {name: "non-numeric returns 0", envValue: "abc", expected: 0}, + {name: "negative returns 0", envValue: "-1", expected: 0}, + {name: "zero returns 0", envValue: "0", expected: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(EnvMaxRecvMsgSize, tt.envValue) + if tt.envValue == "" { + os.Unsetenv(EnvMaxRecvMsgSize) + } + assert.Equal(t, tt.expected, MaxRecvMsgSize()) + }) + } +} + +// largeSyncServer implements just the Sync RPC, returning a response larger than the default 4MB limit. +type largeSyncServer struct { + mgmtProto.UnimplementedManagementServiceServer + responseSize int +} + +func (s *largeSyncServer) GetServerKey(_ context.Context, _ *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) { + // Return a response with a large WiretrusteeConfig to exceed the default limit. + padding := strings.Repeat("x", s.responseSize) + return &mgmtProto.ServerKeyResponse{ + Key: padding, + }, nil +} + +func TestMaxRecvMsgSizeIntegration(t *testing.T) { + const payloadSize = 5 * 1024 * 1024 // 5MB, exceeds 4MB default + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := grpc.NewServer() + mgmtProto.RegisterManagementServiceServer(srv, &largeSyncServer{responseSize: payloadSize}) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + t.Run("default limit rejects large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + _, err = client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "received message larger than max") + }) + + t.Run("increased limit accepts large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(10*1024*1024)), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + resp, err := client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.NoError(t, err) + assert.Len(t, resp.Key, payloadSize) + }) +}