mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 00:54:01 -04:00
[client] Add env var for management gRPC max receive message size (#5622)
This commit is contained in:
@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
// for js, the outer websocket layer takes care of tls
|
// for js, the outer websocket layer takes care of tls
|
||||||
if tlsEnabled && runtime.GOOS != "js" {
|
if tlsEnabled && runtime.GOOS != "js" {
|
||||||
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
opts := []grpc.DialOption{
|
||||||
connCtx,
|
|
||||||
addr,
|
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
}
|
||||||
|
opts = append(opts, extraOpts...)
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(connCtx, addr, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dial context: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -29,6 +31,10 @@ import (
|
|||||||
const ConnectTimeout = 10 * time.Second
|
const ConnectTimeout = 10 * time.Second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||||
|
// for the management client connection. Value is in bytes.
|
||||||
|
EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE"
|
||||||
|
|
||||||
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
||||||
errMsgNoMgmtConnection = "no connection to management"
|
errMsgNoMgmtConnection = "no connection to management"
|
||||||
)
|
)
|
||||||
@@ -66,13 +72,41 @@ type ExposeResponse struct {
|
|||||||
PortAutoAssigned bool
|
PortAutoAssigned bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaxRecvMsgSize returns the configured max gRPC receive message size from
|
||||||
|
// the environment, or 0 if unset (which uses the gRPC default of 4 MB).
|
||||||
|
func MaxRecvMsgSize() int {
|
||||||
|
val := os.Getenv(EnvMaxRecvMsgSize)
|
||||||
|
if val == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := strconv.Atoi(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if size <= 0 {
|
||||||
|
log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient creates a new client to Management service
|
// NewClient creates a new client to Management service
|
||||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||||
var conn *grpc.ClientConn
|
var conn *grpc.ClientConn
|
||||||
|
|
||||||
|
var extraOpts []grpc.DialOption
|
||||||
|
if maxSize := MaxRecvMsgSize(); maxSize > 0 {
|
||||||
|
extraOpts = append(extraOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxSize)))
|
||||||
|
log.Infof("management gRPC max receive message size set to %d bytes", maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent, extraOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create connection: %w", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
95
shared/management/client/grpc_test.go
Normal file
95
shared/management/client/grpc_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaxRecvMsgSize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envValue string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{name: "unset returns 0", envValue: "", expected: 0},
|
||||||
|
{name: "valid value", envValue: "10485760", expected: 10485760},
|
||||||
|
{name: "non-numeric returns 0", envValue: "abc", expected: 0},
|
||||||
|
{name: "negative returns 0", envValue: "-1", expected: 0},
|
||||||
|
{name: "zero returns 0", envValue: "0", expected: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv(EnvMaxRecvMsgSize, tt.envValue)
|
||||||
|
if tt.envValue == "" {
|
||||||
|
os.Unsetenv(EnvMaxRecvMsgSize)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expected, MaxRecvMsgSize())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// largeSyncServer implements just the Sync RPC, returning a response larger than the default 4MB limit.
|
||||||
|
type largeSyncServer struct {
|
||||||
|
mgmtProto.UnimplementedManagementServiceServer
|
||||||
|
responseSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *largeSyncServer) GetServerKey(_ context.Context, _ *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
|
||||||
|
// Return a response with a large WiretrusteeConfig to exceed the default limit.
|
||||||
|
padding := strings.Repeat("x", s.responseSize)
|
||||||
|
return &mgmtProto.ServerKeyResponse{
|
||||||
|
Key: padding,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxRecvMsgSizeIntegration(t *testing.T) {
|
||||||
|
const payloadSize = 5 * 1024 * 1024 // 5MB, exceeds 4MB default
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
mgmtProto.RegisterManagementServiceServer(srv, &largeSyncServer{responseSize: payloadSize})
|
||||||
|
go func() { _ = srv.Serve(lis) }()
|
||||||
|
t.Cleanup(srv.Stop)
|
||||||
|
|
||||||
|
t.Run("default limit rejects large message", func(t *testing.T) {
|
||||||
|
conn, err := grpc.NewClient(
|
||||||
|
lis.Addr().String(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := mgmtProto.NewManagementServiceClient(conn)
|
||||||
|
_, err = client.GetServerKey(context.Background(), &mgmtProto.Empty{})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "received message larger than max")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("increased limit accepts large message", func(t *testing.T) {
|
||||||
|
conn, err := grpc.NewClient(
|
||||||
|
lis.Addr().String(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(10*1024*1024)),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := mgmtProto.NewManagementServiceClient(conn)
|
||||||
|
resp, err := client.GetServerKey(context.Background(), &mgmtProto.Empty{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, resp.Key, payloadSize)
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user