[client,signal,management] Adjust browser client ws proxy paths (#4565)

This commit is contained in:
Viktor Liu
2025-10-02 00:10:47 +02:00
committed by GitHub
parent b5daec3b51
commit 4d7e59f199
10 changed files with 27 additions and 14 deletions

View File

@@ -25,8 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// CreateConnection creates a gRPC client connection with the appropriate transport options
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
@@ -49,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
connCtx,
addr,
transportOption,
WithCustomDialer(tlsEnabled),
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()

View File

@@ -7,6 +7,7 @@ import (
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled)
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -23,6 +23,7 @@ import (
nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
"github.com/netbirdio/netbird/util/wsproxy"
)
type GRPCClient struct {
@@ -54,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
}
opts = append(opts,
nbgrpc.WithCustomDialer(tlsEnabled),
nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent),
grpc.WithIdleTimeout(interval*2),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

View File

@@ -259,7 +259,7 @@ func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Hand
case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")):
gRPCHandler.ServeHTTP(writer, request)
case request.URL.Path == wsproxy.ProxyPath:
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(writer, request)
default:
httpHandler.ServeHTTP(writer, request)

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/wsproxy"
)
const ConnectTimeout = 10 * time.Second
@@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util/wsproxy"
)
// ConnStateNotifier is a wrapper interface of the status recorder
@@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err

View File

@@ -258,7 +258,7 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == wsproxy.ProxyPath:
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
wsProxy.Handler().ServeHTTP(w, r)
default:
grpcServer.ServeHTTP(w, r)

View File

@@ -96,13 +96,14 @@ func (s stringAddr) Network() string { return "tcp" }
func (s stringAddr) String() string { return string(s) }
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
scheme := "wss"
if !tlsEnabled {
scheme = "ws"
}
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath)
wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component)
ws := js.Global().Get("WebSocket").New(wsURL)

View File

@@ -2,9 +2,16 @@ package wsproxy
import "errors"
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers.
// ProxyPath is the base path where the WebSocket proxy is mounted on servers.
const ProxyPath = "/ws-proxy"
// Component paths that are appended to ProxyPath
const (
ManagementComponent = "/management"
SignalComponent = "/signal"
FlowComponent = "/flow"
)
// Common errors
var (
ErrConnectionTimeout = errors.New("WebSocket connection timeout")