diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7cb38fbff..54fbb002c 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -25,8 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -// CreateConnection creates a gRPC client connection with the appropriate transport options -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +// CreateConnection creates a gRPC client connection with the appropriate transport options. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -49,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(tlsEnabled), + WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index a0d6cee0b..96f347c64 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go index e132c0098..b89ec3c21 100644 --- a/client/grpc/dialer_js.go +++ b/client/grpc/dialer_js.go @@ -7,6 +7,7 @@ import ( ) // WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { - return client.WithWebSocketDialer(tlsEnabled) +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled, component) } diff --git a/flow/client/client.go b/flow/client/client.go index 03a4accaf..318fcfe1e 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -23,6 +23,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" + "github.com/netbirdio/netbird/util/wsproxy" ) type GRPCClient struct { @@ -54,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(tlsEnabled), + nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ae9ac4a60..94c633fc6 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -259,7 +259,7 @@ func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Hand case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - case request.URL.Path == wsproxy.ProxyPath: + case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: wsProxy.Handler().ServeHTTP(writer, request) default: httpHandler.ServeHTTP(writer, request) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index f30e965be..076f2532b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) const ConnectTimeout = 10 * time.Second @@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5ca0c0282..31f3372c0 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/signal/cmd/run.go b/signal/cmd/run.go index e2a69a75b..696c44723 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -258,7 +258,7 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.URL.Path == wsproxy.ProxyPath: + case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent: wsProxy.Handler().ServeHTTP(w, r) default: grpcServer.ServeHTTP(w, r) diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go index 2caeed025..bd50f51b5 100644 --- a/util/wsproxy/client/dialer_js.go +++ b/util/wsproxy/client/dialer_js.go @@ -96,13 +96,14 @@ func (s stringAddr) Network() string { return "tcp" } func (s stringAddr) String() string { return string(s) } // WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. -func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { scheme := "wss" if !tlsEnabled { scheme = "ws" } - wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component) ws := js.Global().Get("WebSocket").New(wsURL) diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go index 8d117c7d9..a31c0fbc8 100644 --- a/util/wsproxy/constants.go +++ b/util/wsproxy/constants.go @@ -2,9 +2,16 @@ package wsproxy import "errors" -// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +// ProxyPath is the base path where the WebSocket proxy is mounted on servers. const ProxyPath = "/ws-proxy" +// Component paths that are appended to ProxyPath +const ( + ManagementComponent = "/management" + SignalComponent = "/signal" + FlowComponent = "/flow" +) + // Common errors var ( ErrConnectionTimeout = errors.New("WebSocket connection timeout")