From b5daec3b51ee01ea779f727c74a0aa394a7a3d5d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:10:11 +0200 Subject: [PATCH] [client,signal,management] Add browser client support (#4415) --- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/wasm-build-validation.yml | 67 +++++ .gitmodules | 0 .goreleaser.yaml | 17 ++ client/cmd/debug_js.go | 8 + client/cmd/testutil_test.go | 4 +- client/embed/embed.go | 60 +++- client/grpc/dialer.go | 42 +-- client/grpc/dialer_generic.go | 44 +++ client/grpc/dialer_js.go | 12 + client/iface/bind/error.go | 7 + client/iface/bind/ice_bind.go | 58 ++-- client/iface/bind/recv_msg.go | 6 + client/iface/bind/relay_bind.go | 125 ++++++++ client/iface/configurer/name.go | 2 +- client/iface/configurer/uapi.go | 2 +- client/iface/configurer/uapi_js.go | 23 ++ client/iface/device/device_netstack.go | 28 +- client/iface/device/device_netstack_test.go | 27 ++ client/iface/iface_destroy_js.go | 6 + client/iface/iface_new_android.go | 4 +- client/iface/iface_new_darwin.go | 2 +- client/iface/iface_new_freebsd.go | 4 +- client/iface/iface_new_ios.go | 2 +- client/iface/iface_new_js.go | 27 ++ client/iface/iface_new_linux.go | 4 +- client/iface/iface_new_windows.go | 2 +- client/iface/netstack/env.go | 2 + client/iface/netstack/env_js.go | 12 + client/iface/wgproxy/bind/proxy.go | 23 +- client/iface/wgproxy/factory_usp.go | 11 +- client/iface/wgproxy/proxy_linux_test.go | 2 +- client/iface/wgproxy/proxy_seed_test.go | 2 +- client/internal/dns/server_js.go | 5 + client/internal/dns/unclean_shutdown_js.go | 19 ++ client/internal/engine.go | 20 +- client/internal/engine_generic.go | 19 ++ client/internal/engine_js.go | 18 ++ client/internal/engine_test.go | 8 +- .../networkmonitor/check_change_js.go | 12 + .../routemanager/systemops/systemops_js.go | 48 ++++ .../systemops/systemops_nonlinux.go | 2 +- client/server/server_test.go | 17 +- client/ssh/client.go | 2 + client/ssh/login.go | 2 + client/ssh/server.go | 2 + client/ssh/server_mock.go | 2 + client/ssh/server_test.go | 2 + client/ssh/ssh_js.go | 137 +++++++++ client/ssh/util.go | 2 + client/system/info_js.go | 231 +++++++++++++++ client/wasm/cmd/main.go | 245 ++++++++++++++++ client/wasm/internal/http/http.go | 100 +++++++ client/wasm/internal/rdp/cert_validation.go | 96 +++++++ client/wasm/internal/rdp/rdcleanpath.go | 271 ++++++++++++++++++ .../wasm/internal/rdp/rdcleanpath_handlers.go | 251 ++++++++++++++++ client/wasm/internal/ssh/client.go | 213 ++++++++++++++ client/wasm/internal/ssh/handlers.go | 78 +++++ client/wasm/internal/ssh/key.go | 50 ++++ encryption/route53.go | 2 + flow/client/client.go | 5 +- go.mod | 2 +- go.sum | 4 +- management/internals/server/controllers.go | 8 +- management/internals/server/modules.go | 4 + management/internals/server/server.go | 32 ++- management/server/account.go | 6 + management/server/account/manager.go | 4 +- management/server/account_test.go | 38 +-- management/server/dns_test.go | 4 +- management/server/grpcserver.go | 5 +- .../http/handlers/peers/peers_handler.go | 83 ++++++ management/server/management_proto_test.go | 3 +- management/server/management_test.go | 3 +- management/server/mock_server/account_mock.go | 12 +- management/server/nameserver_test.go | 4 +- .../server/networks/resources/manager.go | 4 +- management/server/peer.go | 58 +++- management/server/peer/peer.go | 12 + management/server/peer_test.go | 74 ++--- .../server/peers/ephemeral/interface.go | 14 + .../ephemeral/manager}/ephemeral.go | 2 +- .../ephemeral/manager}/ephemeral_test.go | 59 +++- management/server/policy.go | 12 + management/server/store/sql_store.go | 19 ++ management/server/store/store.go | 1 + management/server/types/account.go | 32 ++- management/server/types/policy.go | 87 ++++++ management/server/types/resource.go | 13 +- management/server/user_test.go | 4 +- shared/management/client/client_test.go | 4 +- shared/management/http/api/openapi.yml | 81 +++++- shared/management/http/api/types.gen.go | 28 ++ shared/management/proto/management.pb.go | 2 +- shared/relay/client/client.go | 12 +- shared/relay/client/dialer/ws/conn.go | 3 +- .../client/dialer/ws/dialopts_generic.go | 11 + shared/relay/client/dialer/ws/dialopts_js.go | 10 + shared/relay/client/dialer/ws/ws.go | 4 +- shared/relay/client/dialers_generic.go | 19 ++ shared/relay/client/dialers_js.go | 13 + signal/cmd/run.go | 52 +++- util/util_js.go | 8 + util/wsproxy/client/dialer_js.go | 171 +++++++++++ util/wsproxy/constants.go | 13 + util/wsproxy/server/metrics.go | 118 ++++++++ util/wsproxy/server/proxy.go | 227 +++++++++++++++ 107 files changed, 3591 insertions(+), 284 deletions(-) create mode 100644 .github/workflows/wasm-build-validation.yml create mode 100644 .gitmodules create mode 100644 client/cmd/debug_js.go create mode 100644 client/grpc/dialer_generic.go create mode 100644 client/grpc/dialer_js.go create mode 100644 client/iface/bind/error.go create mode 100644 client/iface/bind/recv_msg.go create mode 100644 client/iface/bind/relay_bind.go create mode 100644 client/iface/configurer/uapi_js.go create mode 100644 client/iface/device/device_netstack_test.go create mode 100644 client/iface/iface_destroy_js.go create mode 100644 client/iface/iface_new_js.go create mode 100644 client/iface/netstack/env_js.go create mode 100644 client/internal/dns/server_js.go create mode 100644 client/internal/dns/unclean_shutdown_js.go create mode 100644 client/internal/engine_generic.go create mode 100644 client/internal/engine_js.go create mode 100644 client/internal/networkmonitor/check_change_js.go create mode 100644 client/internal/routemanager/systemops/systemops_js.go create mode 100644 client/ssh/ssh_js.go create mode 100644 client/system/info_js.go create mode 100644 client/wasm/cmd/main.go create mode 100644 client/wasm/internal/http/http.go create mode 100644 client/wasm/internal/rdp/cert_validation.go create mode 100644 client/wasm/internal/rdp/rdcleanpath.go create mode 100644 client/wasm/internal/rdp/rdcleanpath_handlers.go create mode 100644 client/wasm/internal/ssh/client.go create mode 100644 client/wasm/internal/ssh/handlers.go create mode 100644 client/wasm/internal/ssh/key.go create mode 100644 management/server/peers/ephemeral/interface.go rename management/server/{ => peers/ephemeral/manager}/ephemeral.go (99%) rename management/server/{ => peers/ephemeral/manager}/ephemeral_test.go (75%) create mode 100644 shared/relay/client/dialer/ws/dialopts_generic.go create mode 100644 shared/relay/client/dialer/ws/dialopts_js.go create mode 100644 shared/relay/client/dialers_generic.go create mode 100644 shared/relay/client/dialers_js.go create mode 100644 util/util_js.go create mode 100644 util/wsproxy/client/dialer_js.go create mode 100644 util/wsproxy/constants.go create mode 100644 util/wsproxy/server/metrics.go create mode 100644 util/wsproxy/server/proxy.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 7e6583cc6..2845b05a5 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros skip: go.mod,go.sum golangci: strategy: diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml new file mode 100644 index 000000000..e4ac799bc --- /dev/null +++ b/.github/workflows/wasm-build-validation.yml @@ -0,0 +1,67 @@ +name: Wasm + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + js_lint: + name: "JS / Lint" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + - name: Install golangci-lint + uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc + with: + version: latest + install-mode: binary + skip-cache: true + skip-pkg-cache: true + skip-build-cache: true + - name: Run golangci-lint for WASM + run: | + GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/... + continue-on-error: true + + js_build: + name: "JS / Build" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Build Wasm client + run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd + env: + CGO_ENABLED: 0 + - name: Check Wasm build size + run: | + echo "Wasm build size:" + ls -lh netbird.wasm + + SIZE=$(stat -c%s netbird.wasm) + SIZE_MB=$((SIZE / 1024 / 1024)) + + echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" + + if [ ${SIZE} -gt 52428800 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + exit 1 + fi + diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..e69de29bb diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 59a95c89a..952e946dc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -2,6 +2,18 @@ version: 2 project_name: netbird builds: + - id: netbird-wasm + dir: client/wasm/cmd + binary: netbird + env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0] + goos: + - js + goarch: + - wasm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird dir: client binary: netbird @@ -115,6 +127,11 @@ archives: - builds: - netbird - netbird-static + - id: netbird-wasm + builds: + - netbird-wasm + name_template: "{{ .ProjectName }}_{{ .Version }}" + format: binary nfpms: - maintainer: Netbird diff --git a/client/cmd/debug_js.go b/client/cmd/debug_js.go new file mode 100644 index 000000000..d06fb8efc --- /dev/null +++ b/client/cmd/debug_js.go @@ -0,0 +1,8 @@ +package cmd + +import "context" + +// SetupDebugHandler is a no-op for WASM +func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) { + // Debug handler not needed for WASM +} diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 99ccb1539..bd3209605 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" @@ -20,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 0bfc7a37c..e918235ed 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -23,23 +23,29 @@ import ( var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientNotStarted = errors.New("client not started") +var ErrConfigNotInitialized = errors.New("config not initialized") -// Client manages a netbird embedded client instance +// Client manages a netbird embedded client instance. type Client struct { deviceName string config *profilemanager.Config mu sync.Mutex cancel context.CancelFunc setupKey string + jwtToken string connect *internal.ConnectClient } -// Options configures a new Client +// Options configures a new Client. type Options struct { // DeviceName is this peer's name in the network DeviceName string // SetupKey is used for authentication SetupKey string + // JWTToken is used for JWT-based authentication + JWTToken string + // PrivateKey is used for direct private key authentication + PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string // PreSharedKey is the pre-shared key for the WireGuard interface @@ -58,8 +64,35 @@ type Options struct { DisableClientRoutes bool } -// New creates a new netbird embedded client +// validateCredentials checks that exactly one credential type is provided +func (opts *Options) validateCredentials() error { + credentialsProvided := 0 + if opts.SetupKey != "" { + credentialsProvided++ + } + if opts.JWTToken != "" { + credentialsProvided++ + } + if opts.PrivateKey != "" { + credentialsProvided++ + } + + if credentialsProvided == 0 { + return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided") + } + if credentialsProvided > 1 { + return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified") + } + + return nil +} + +// New creates a new netbird embedded client. func New(opts Options) (*Client, error) { + if err := opts.validateCredentials(); err != nil { + return nil, err + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) { return nil, fmt.Errorf("create config: %w", err) } + if opts.PrivateKey != "" { + config.PrivateKey = opts.PrivateKey + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, + jwtToken: opts.JWTToken, config: config, }, nil } @@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error { ctx := internal.CtxInitState(context.Background()) // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { + if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } @@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error { } } +// GetConfig returns a copy of the internal client config. +func (c *Client) GetConfig() (profilemanager.Config, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.config == nil { + return profilemanager.Config{}, ErrConfigNotInitialized + } + return *c.config, nil +} + // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } -// ListenTCP listens on the given address in the netbird network +// ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { nsnet, addr, err := c.getNet() @@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { return nsnet.ListenTCP(tcpAddr) } -// ListenUDP listens on the given address in the netbird network +// ListenUDP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenUDP(address string) (net.PacketConn, error) { nsnet, addr, err := c.getNet() diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 69e3f088c..7cb38fbff 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,15 +4,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "fmt" - "net" - "os/user" "runtime" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -20,37 +14,10 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - nbnet "github.com/netbirdio/netbird/client/net" - "github.com/netbirdio/netbird/util/embeddedroots" ) -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - if runtime.GOOS == "linux" { - currentUser, err := user.Current() - if err != nil { - return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) - } - - // the custom dialer requires root permissions which are not required for use cases run as non-root - if currentUser.Uid != "0" { - log.Debug("Not running as root, using standard dialer") - dialer := &net.Dialer{} - return dialer.DialContext(ctx, "tcp", addr) - } - } - - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) - } - return conn, nil - }) -} - -// grpcDialBackoff is the backoff mechanism for the grpc calls +// Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() b.MaxElapsedTime = 10 * time.Second @@ -58,6 +25,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// CreateConnection creates a gRPC client connection with the appropriate transport options func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { @@ -68,7 +36,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - RootCAs: certPool, + // for js, outer websocket layer takes care of tls verification via WithCustomDialer + InsecureSkipVerify: runtime.GOOS == "js", + RootCAs: certPool, })) } @@ -79,7 +49,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(), + WithCustomDialer(tlsEnabled), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go new file mode 100644 index 000000000..a0d6cee0b --- /dev/null +++ b/client/grpc/dialer_generic.go @@ -0,0 +1,44 @@ +//go:build !js + +package grpc + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) + } + return conn, nil + }) +} diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go new file mode 100644 index 000000000..e132c0098 --- /dev/null +++ b/client/grpc/dialer_js.go @@ -0,0 +1,12 @@ +package grpc + +import ( + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy/client" +) + +// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled) +} diff --git a/client/iface/bind/error.go b/client/iface/bind/error.go new file mode 100644 index 000000000..db7c23144 --- /dev/null +++ b/client/iface/bind/error.go @@ -0,0 +1,7 @@ +package bind + +import "fmt" + +var ( + ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM") +) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ef630b9d0..dfb22ecde 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,3 +1,5 @@ +//go:build !js + package bind import ( @@ -21,11 +23,6 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -type RecvMessage struct { - Endpoint *Endpoint - Buffer []byte -} - type receiverCreator struct { iceBind *ICEBind } @@ -43,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn - endpoints map[netip.Addr]net.Conn - endpointsMu sync.Mutex + address wgaddr.Address + mtu uint16 + + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + recvChan chan recvMessage // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // new closed channel. With the closedChanMu we can safely close the channel and create a new one - closedChan chan struct{} - closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. - closed bool - - muUDPMux sync.Mutex - udpMux *udpmux.UniversalUDPMuxDefault - address wgaddr.Address - mtu uint16 + closedChan chan struct{} + closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. + closed bool activityRecorder *ActivityRecorder + + muUDPMux sync.Mutex + udpMux *udpmux.UniversalUDPMuxDefault } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, + address: address, + mtu: mtu, endpoints: make(map[netip.Addr]net.Conn), + recvChan: make(chan recvMessage, 1), closedChan: make(chan struct{}), closed: true, - mtu: mtu, - address: address, activityRecorder: NewActivityRecorder(), } @@ -84,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg return ib } -func (s *ICEBind) MTU() uint16 { - return s.mtu -} - func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -140,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { delete(b.endpoints, fakeIP) } +func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-b.closedChan: + return + case <-ctx.Done(): + return + case b.recvChan <- recvMessage{ep, buf}: + } +} + func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { b.endpointsMu.Lock() conn, ok := b.endpoints[ep.DstIP()] @@ -156,14 +160,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { - select { - case <-ctx.Done(): - return - case b.recvChan <- msg: - } -} - func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() diff --git a/client/iface/bind/recv_msg.go b/client/iface/bind/recv_msg.go new file mode 100644 index 000000000..65baffaac --- /dev/null +++ b/client/iface/bind/recv_msg.go @@ -0,0 +1,6 @@ +package bind + +type recvMessage struct { + Endpoint *Endpoint + Buffer []byte +} diff --git a/client/iface/bind/relay_bind.go b/client/iface/bind/relay_bind.go new file mode 100644 index 000000000..4c179d6a5 --- /dev/null +++ b/client/iface/bind/relay_bind.go @@ -0,0 +1,125 @@ +package bind + +import ( + "context" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" + + "github.com/netbirdio/netbird/client/iface/udpmux" +) + +// RelayBindJS is a conn.Bind implementation for WebAssembly environments. +// Do not limit to build only js, because we want to be able to run tests +type RelayBindJS struct { + *conn.StdNetBind + + recvChan chan recvMessage + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + activityRecorder *ActivityRecorder + ctx context.Context + cancel context.CancelFunc +} + +func NewRelayBindJS() *RelayBindJS { + return &RelayBindJS{ + recvChan: make(chan recvMessage, 100), + endpoints: make(map[netip.Addr]net.Conn), + activityRecorder: NewActivityRecorder(), + } +} + +// Open creates a receive function for handling relay packets in WASM. +func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + log.Debugf("Open: creating receive function for port %d", uport) + + s.ctx, s.cancel = context.WithCancel(context.Background()) + + receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + select { + case <-s.ctx.Done(): + return 0, net.ErrClosed + case msg, ok := <-s.recvChan: + if !ok { + return 0, net.ErrClosed + } + copy(bufs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = conn.Endpoint(msg.Endpoint) + return 1, nil + } + } + + log.Debugf("Open: receive function created, returning port %d", uport) + return []conn.ReceiveFunc{receiveFn}, uport, nil +} + +func (s *RelayBindJS) Close() error { + if s.cancel == nil { + return nil + } + log.Debugf("close RelayBindJS") + s.cancel() + return nil +} + +func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-s.ctx.Done(): + return + case <-ctx.Done(): + return + case s.recvChan <- recvMessage{ep, buf}: + } +} + +// Send forwards packets through the relay connection for WASM. +func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error { + if ep == nil { + return nil + } + + fakeIP := ep.DstIP() + + s.endpointsMu.Lock() + relayConn, ok := s.endpoints[fakeIP] + s.endpointsMu.Unlock() + + if !ok { + return nil + } + + for _, buf := range bufs { + if _, err := relayConn.Write(buf); err != nil { + return err + } + } + + return nil +} + +func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + b.endpointsMu.Lock() + b.endpoints[fakeIP] = conn + b.endpointsMu.Unlock() +} + +func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) { + s.endpointsMu.Lock() + defer s.endpointsMu.Unlock() + + delete(s.endpoints, fakeIP) +} + +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, ErrUDPMUXNotSupported +} + +func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder { + return s.activityRecorder +} diff --git a/client/iface/configurer/name.go b/client/iface/configurer/name.go index 3b9abc0e8..a8469e0b4 100644 --- a/client/iface/configurer/name.go +++ b/client/iface/configurer/name.go @@ -1,4 +1,4 @@ -//go:build linux || windows || freebsd +//go:build linux || windows || freebsd || js || wasip1 package configurer diff --git a/client/iface/configurer/uapi.go b/client/iface/configurer/uapi.go index 4801841de..f85c7852a 100644 --- a/client/iface/configurer/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build !windows && !js package configurer diff --git a/client/iface/configurer/uapi_js.go b/client/iface/configurer/uapi_js.go new file mode 100644 index 000000000..d0188eb35 --- /dev/null +++ b/client/iface/configurer/uapi_js.go @@ -0,0 +1,23 @@ +package configurer + +import ( + "net" +) + +type noopListener struct{} + +func (n *noopListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (n *noopListener) Close() error { + return nil +} + +func (n *noopListener) Addr() net.Addr { + return nil +} + +func openUAPI(deviceName string) (net.Listener, error) { + return &noopListener{}, nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index a6ef47027..e37321b68 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,9 +1,11 @@ package device import ( + "errors" "fmt" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -15,6 +17,12 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +type Bind interface { + conn.Bind + GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) + ActivityRecorder() *bind.ActivityRecorder +} + type TunNetstackDevice struct { name string address wgaddr.Address @@ -22,7 +30,7 @@ type TunNetstackDevice struct { key string mtu uint16 listenAddress string - iceBind *bind.ICEBind + bind Bind device *device.Device filteredDevice *FilteredDevice @@ -33,7 +41,7 @@ type TunNetstackDevice struct { net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: iceBind, + bind: bind, } } @@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { t.device = device.NewDevice( t.filteredDevice, - t.iceBind, + t.bind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() @@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - udpMux, err := t.iceBind.GetICEMux() - if err != nil { + udpMux, err := t.bind.GetICEMux() + if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) { return nil, err } - t.udpMux = udpMux + + if udpMux != nil { + t.udpMux = udpMux + } + log.Debugf("netstack device is ready to use") return udpMux, nil } diff --git a/client/iface/device/device_netstack_test.go b/client/iface/device/device_netstack_test.go new file mode 100644 index 000000000..52059602f --- /dev/null +++ b/client/iface/device/device_netstack_test.go @@ -0,0 +1,27 @@ +package device + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestNewNetstackDevice(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24") + + relayBind := bind.NewRelayBindJS() + nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr()) + + cfgr, err := nsTun.Create() + if err != nil { + t.Fatalf("failed to create netstack device: %v", err) + } + if cfgr == nil { + t.Fatal("expected non-nil configurer") + } +} diff --git a/client/iface/iface_destroy_js.go b/client/iface/iface_destroy_js.go new file mode 100644 index 000000000..b443273c3 --- /dev/null +++ b/client/iface/iface_destroy_js.go @@ -0,0 +1,6 @@ +package iface + +// Destroy is a no-op on WASM +func (w *WGIface) Destroy() error { + return nil +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 26952f48d..3b68f63f2 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 7dd74d571..9f21ec950 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go index 86ed14ce1..a342bd579 100644 --- a/client/iface/iface_new_freebsd.go +++ b/client/iface/iface_new_freebsd.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -33,7 +33,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 06ccf0be1..5d6a32e39 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go new file mode 100644 index 000000000..ad913ab04 --- /dev/null +++ b/client/iface/iface_new_js.go @@ -0,0 +1,27 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + relayBind := bind.NewRelayBindJS() + + wgIface := &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), + } + + return wgIface, nil +} diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_linux.go index 77fd30fae..d84035403 100644 --- a/client/iface/iface_new_linux.go +++ b/client/iface/iface_new_linux.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 349c5b33b..dfd9028e7 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index cdbf975b1..dd8cf29a3 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -1,3 +1,5 @@ +//go:build !js + package netstack import ( diff --git a/client/iface/netstack/env_js.go b/client/iface/netstack/env_js.go new file mode 100644 index 000000000..05c20f036 --- /dev/null +++ b/client/iface/netstack/env_js.go @@ -0,0 +1,12 @@ +package netstack + +const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE" + +// IsEnabled always returns true for js since it's the only mode available +func IsEnabled() bool { + return true +} + +func ListenAddr() string { + return "" +} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index dbc694e91..eb585d8a2 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,15 +16,14 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) -type IceBind interface { - SetEndpoint(fakeIP netip.Addr, conn net.Conn) - RemoveEndpoint(fakeIP netip.Addr) - Recv(ctx context.Context, msg bind.RecvMessage) - MTU() uint16 +type Bind interface { + SetEndpoint(addr netip.Addr, conn net.Conn) + RemoveEndpoint(addr netip.Addr) + ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte) } type ProxyBind struct { - bind IceBind + bind Bind // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address wgRelayedEndpoint *bind.Endpoint @@ -40,13 +39,15 @@ type ProxyBind struct { isStarted bool closeListener *listener.CloseListener + mtu uint16 } -func NewProxyBind(bind IceBind) *ProxyBind { +func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { p := &ProxyBind{ bind: bind, closeListener: listener.NewCloseListener(), pausedCond: sync.NewCond(&sync.Mutex{}), + mtu: mtu + bufsize.WGBufferOverhead, } return p @@ -174,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.mtu) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -190,11 +191,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - msg := bind.RecvMessage{ - Endpoint: p.wgCurrentUsed, - Buffer: buf[:n], - } - p.bind.Recv(ctx, msg) + p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 141b4c1f9..a1b1c34d7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -3,24 +3,25 @@ package wgproxy import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) type USPFactory struct { - bind *bind.ICEBind + bind proxyBind.Bind + mtu uint16 } -func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { +func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory { log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ - bind: iceBind, + bind: bind, + mtu: mtu, } return f } func (w *USPFactory) GetProxy() Proxy { - return proxyBind.NewProxyBind(w.bind) + return proxyBind.NewProxyBind(w.bind, w.mtu) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 9526e91d2..dd24d1cdc 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -74,7 +74,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go index 4d244f18a..ad375ccde 100644 --- a/client/iface/wgproxy/proxy_seed_test.go +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -30,7 +30,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/internal/dns/server_js.go b/client/internal/dns/server_js.go new file mode 100644 index 000000000..a8bc35d09 --- /dev/null +++ b/client/internal/dns/server_js.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (hostManager, error) { + return &noopHostConfigurator{}, nil +} diff --git a/client/internal/dns/unclean_shutdown_js.go b/client/internal/dns/unclean_shutdown_js.go new file mode 100644 index 000000000..378ffc164 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_js.go @@ -0,0 +1,19 @@ +package dns + +import ( + "context" +) + +type ShutdownState struct{} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} + +func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error { + return nil +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 828bc6e94..3fa0b58a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -453,8 +453,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } - - // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -466,14 +464,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("initialize dns server: %w", err) } - iceCfg := icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - } + iceCfg := e.createICEConfig() e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr.Start(e.ctx) @@ -1347,14 +1338,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV Addr: e.getRosenpassAddr(), PermissiveMode: e.config.RosenpassPermissive, }, - ICEConfig: icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - }, + ICEConfig: e.createICEConfig(), } serviceDependencies := peer.ServiceDependencies{ diff --git a/client/internal/engine_generic.go b/client/internal/engine_generic.go new file mode 100644 index 000000000..34a75e45b --- /dev/null +++ b/client/internal/engine_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for non-WASM environments +func (e *Engine) createICEConfig() icemaker.Config { + return icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } +} diff --git a/client/internal/engine_js.go b/client/internal/engine_js.go new file mode 100644 index 000000000..dce3c57fb --- /dev/null +++ b/client/internal/engine_js.go @@ -0,0 +1,18 @@ +//go:build js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for WASM environment. +func (e *Engine) createICEConfig() icemaker.Config { + cfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + return cfg +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 4d2e81f43..344104405 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,6 +27,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" @@ -42,10 +46,8 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/internal/networkmonitor/check_change_js.go b/client/internal/networkmonitor/check_change_js.go new file mode 100644 index 000000000..640cf7184 --- /dev/null +++ b/client/internal/networkmonitor/check_change_js.go @@ -0,0 +1,12 @@ +package networkmonitor + +import ( + "context" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + // No-op for WASM - network changes don't apply + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go new file mode 100644 index 000000000..808507fc9 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -0,0 +1,48 @@ +package systemops + +import ( + "errors" + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +var ErrRouteNotSupported = errors.New("route operations not supported on js") + +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func GetRoutesFromTable() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +func hasSeparateRouting() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +// GetDetailedRoutesFromTable returns empty routes for WASM. +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + return []DetailedRoute{}, nil +} + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error { + return nil +} + +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error { + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 83b64e82b..905a7bc12 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !linux && !ios && !js package systemops diff --git a/client/server/server_test.go b/client/server/server_test.go index 755925003..e0a4805f6 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,23 +10,26 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -314,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/ssh/client.go b/client/ssh/client.go index 2dc70e8fc..afba347f8 100644 --- a/client/ssh/client.go +++ b/client/ssh/client.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/login.go b/client/ssh/login.go index d1d56ceb0..cb2615e55 100644 --- a/client/ssh/login.go +++ b/client/ssh/login.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server.go b/client/ssh/server.go index 1f2001d0f..8c5db2547 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go index cc080ffdb..76f43fd4e 100644 --- a/client/ssh/server_mock.go +++ b/client/ssh/server_mock.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import "context" diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go index 5caca1834..1f310c2bb 100644 --- a/client/ssh/server_test.go +++ b/client/ssh/server_test.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/ssh_js.go b/client/ssh/ssh_js.go new file mode 100644 index 000000000..8cea88702 --- /dev/null +++ b/client/ssh/ssh_js.go @@ -0,0 +1,137 @@ +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "strings" + + "golang.org/x/crypto/ssh" +) + +var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment") + +// Server is a dummy SSH server interface for WASM. +type Server interface { + Start() error + Stop() error + EnableSSH(enabled bool) + AddAuthorizedKey(peer string, key string) error + RemoveAuthorizedKey(key string) +} + +type dummyServer struct{} + +func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { + return &dummyServer{}, nil +} + +func NewServer(addr string) Server { + return &dummyServer{} +} + +func (s *dummyServer) Start() error { + return ErrSSHNotSupported +} + +func (s *dummyServer) Stop() error { + return nil +} + +func (s *dummyServer) EnableSSH(enabled bool) { +} + +func (s *dummyServer) AddAuthorizedKey(peer string, key string) error { + return nil +} + +func (s *dummyServer) RemoveAuthorizedKey(key string) { +} + +type Client struct{} + +func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) { + return nil, ErrSSHNotSupported +} + +func (c *Client) Close() error { + return nil +} + +func (c *Client) Run(command []string) error { + return ErrSSHNotSupported +} + +type SessionRecorder struct{} + +func NewSessionRecorder() *SessionRecorder { + return &SessionRecorder{} +} + +func (r *SessionRecorder) Record(session string, data []byte) { +} + +func GetUserShell() string { + return "/bin/sh" +} + +func LookupUserInfo(username string) (string, string, error) { + return "", "", ErrSSHNotSupported +} + +const DefaultSSHPort = 44338 + +const ED25519 = "ed25519" + +func isRoot() bool { + return false +} + +func GeneratePrivateKey(keyType string) ([]byte, error) { + if keyType != ED25519 { + return nil, errors.New("only ED25519 keys are supported in WASM") + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, err + } + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + } + + pemBytes := pem.EncodeToMemory(pemBlock) + return pemBytes, nil +} + +func GeneratePublicKey(privateKey []byte) ([]byte, error) { + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + block, _ := pem.Decode(privateKey) + if block != nil { + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + signer, err = ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + + pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey()) + return []byte(strings.TrimSpace(string(pubKeyBytes))), nil +} diff --git a/client/ssh/util.go b/client/ssh/util.go index cf5f1396e..a54a609bc 100644 --- a/client/ssh/util.go +++ b/client/ssh/util.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/system/info_js.go b/client/system/info_js.go new file mode 100644 index 000000000..994d439a7 --- /dev/null +++ b/client/system/info_js.go @@ -0,0 +1,231 @@ +package system + +import ( + "context" + "runtime" + "strings" + "syscall/js" + + "github.com/netbirdio/netbird/version" +) + +// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + +// GetInfo retrieves system information for WASM environment +func GetInfo(_ context.Context) *Info { + info := &Info{ + GoOS: runtime.GOOS, + Kernel: runtime.GOARCH, + KernelVersion: runtime.GOARCH, + Platform: runtime.GOARCH, + OS: runtime.GOARCH, + Hostname: "wasm-client", + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + } + + collectBrowserInfo(info) + collectLocationInfo(info) + collectSystemInfo(info) + return info +} + +func collectBrowserInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + collectUserAgent(info, navigator) + collectPlatform(info, navigator) + collectCPUInfo(info, navigator) +} + +func collectUserAgent(info *Info, navigator js.Value) { + ua := navigator.Get("userAgent") + if ua.IsUndefined() { + return + } + + userAgent := ua.String() + os, osVersion := parseOSFromUserAgent(userAgent) + if os != "" { + info.OS = os + } + if osVersion != "" { + info.OSVersion = osVersion + } +} + +func collectPlatform(info *Info, navigator js.Value) { + // Try regular platform property + if plat := navigator.Get("platform"); !plat.IsUndefined() { + if platStr := plat.String(); platStr != "" { + info.Platform = platStr + } + } + + // Try newer userAgentData API for more accurate platform + userAgentData := navigator.Get("userAgentData") + if userAgentData.IsUndefined() { + return + } + + platformInfo := userAgentData.Get("platform") + if !platformInfo.IsUndefined() { + if platStr := platformInfo.String(); platStr != "" { + info.Platform = platStr + } + } +} + +func collectCPUInfo(info *Info, navigator js.Value) { + hardwareConcurrency := navigator.Get("hardwareConcurrency") + if !hardwareConcurrency.IsUndefined() { + info.CPUs = hardwareConcurrency.Int() + } +} + +func collectLocationInfo(info *Info) { + location := js.Global().Get("location") + if location.IsUndefined() { + return + } + + if host := location.Get("hostname"); !host.IsUndefined() { + hostnameStr := host.String() + if hostnameStr != "" && hostnameStr != "localhost" { + info.Hostname = hostnameStr + } + } +} + +func checkFileAndProcess(_ []string) ([]File, error) { + return []File{}, nil +} + +func collectSystemInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + if vendor := navigator.Get("vendor"); !vendor.IsUndefined() { + info.SystemManufacturer = vendor.String() + } + + if product := navigator.Get("product"); !product.IsUndefined() { + info.SystemProductName = product.String() + } + + if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() { + ua := userAgent.String() + info.Environment = detectEnvironmentFromUA(ua) + } +} + +func parseOSFromUserAgent(userAgent string) (string, string) { + if userAgent == "" { + return "", "" + } + + switch { + case strings.Contains(userAgent, "Windows NT"): + return parseWindowsVersion(userAgent) + case strings.Contains(userAgent, "Mac OS X"): + return parseMacOSVersion(userAgent) + case strings.Contains(userAgent, "FreeBSD"): + return "FreeBSD", "" + case strings.Contains(userAgent, "OpenBSD"): + return "OpenBSD", "" + case strings.Contains(userAgent, "NetBSD"): + return "NetBSD", "" + case strings.Contains(userAgent, "Linux"): + return parseLinuxVersion(userAgent) + case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"): + return parseiOSVersion(userAgent) + case strings.Contains(userAgent, "CrOS"): + return "ChromeOS", "" + default: + return "", "" + } +} + +func parseWindowsVersion(userAgent string) (string, string) { + switch { + case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"): + return "Windows", "10/11" + case strings.Contains(userAgent, "Windows NT 10.0"): + return "Windows", "10" + case strings.Contains(userAgent, "Windows NT 6.3"): + return "Windows", "8.1" + case strings.Contains(userAgent, "Windows NT 6.2"): + return "Windows", "8" + case strings.Contains(userAgent, "Windows NT 6.1"): + return "Windows", "7" + default: + return "Windows", "Unknown" + } +} + +func parseMacOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "Mac OS X ") + if idx == -1 { + return "macOS", "Unknown" + } + + versionStart := idx + len("Mac OS X ") + versionEnd := strings.Index(userAgent[versionStart:], ")") + if versionEnd <= 0 { + return "macOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "macOS", ver +} + +func parseLinuxVersion(userAgent string) (string, string) { + if strings.Contains(userAgent, "Android") { + return "Android", extractAndroidVersion(userAgent) + } + if strings.Contains(userAgent, "Ubuntu") { + return "Ubuntu", "" + } + return "Linux", "" +} + +func parseiOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "OS ") + if idx == -1 { + return "iOS", "Unknown" + } + + versionStart := idx + 3 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd <= 0 { + return "iOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "iOS", ver +} + +func extractAndroidVersion(userAgent string) string { + if idx := strings.Index(userAgent, "Android "); idx != -1 { + versionStart := idx + len("Android ") + versionEnd := strings.IndexAny(userAgent[versionStart:], ";)") + if versionEnd > 0 { + return userAgent[versionStart : versionStart+versionEnd] + } + } + return "Unknown" +} + +func detectEnvironmentFromUA(_ string) Environment { + return Environment{} +} diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go new file mode 100644 index 000000000..d542e2739 --- /dev/null +++ b/client/wasm/cmd/main.go @@ -0,0 +1,245 @@ +//go:build js + +package main + +import ( + "context" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/client/wasm/internal/http" + "github.com/netbirdio/netbird/client/wasm/internal/rdp" + "github.com/netbirdio/netbird/client/wasm/internal/ssh" + "github.com/netbirdio/netbird/util" +) + +const ( + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" +) + +func main() { + js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor)) + + select {} +} + +func startClient(ctx context.Context, nbClient *netbird.Client) error { + log.Info("Starting NetBird client...") + if err := nbClient.Start(ctx); err != nil { + return err + } + log.Info("NetBird client started successfully") + return nil +} + +// parseClientOptions extracts NetBird options from JavaScript object +func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { + options := netbird.Options{ + DeviceName: "dashboard-client", + LogLevel: defaultLogLevel, + } + + if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() { + options.JWTToken = jwtToken.String() + } + + if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() { + options.SetupKey = setupKey.String() + } + + if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() { + options.PrivateKey = privateKey.String() + } + + if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() { + mgmtURLStr := mgmtURL.String() + if mgmtURLStr != "" { + options.ManagementURL = mgmtURLStr + } + } + + if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() { + options.LogLevel = logLevel.String() + } + + if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() { + options.DeviceName = deviceName.String() + } + + return options, nil +} + +// createStartMethod creates the start method for the client +func createStartMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout) + defer cancel() + + if err := startClient(ctx, client); err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createStopMethod creates the stop method for the client +func createStopMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout) + defer cancel() + + if err := client.Stop(ctx); err != nil { + log.Errorf("Error stopping client: %v", err) + reject.Invoke(js.ValueOf(err.Error())) + return + } + + log.Info("NetBird client stopped") + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createSSHMethod creates the SSH connection method +func createSSHMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + username := "root" + if len(args) > 2 && args[2].String() != "" { + username = args[2].String() + } + + return createPromise(func(resolve, reject js.Value) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username); err != nil { + reject.Invoke(err.Error()) + return + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + reject.Invoke(err.Error()) + return + } + + jsInterface := ssh.CreateJSInterface(sshClient) + resolve.Invoke(jsInterface) + }) + }) +} + +// createProxyRequestMethod creates the proxyRequest method +func createProxyRequestMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: request details required") + } + + request := args[0] + + return createPromise(func(resolve, reject js.Value) { + response, err := http.ProxyRequest(client, request) + if err != nil { + reject.Invoke(err.Error()) + return + } + resolve.Invoke(response) + }) + }) +} + +// createRDPProxyMethod creates the RDP proxy method +func createRDPProxyMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: hostname and port required") + } + + proxy := rdp.NewRDCleanPathProxy(client) + return proxy.CreateProxy(args[0].String(), args[1].String()) + }) +} + +// createPromise is a helper to create JavaScript promises +func createPromise(handler func(resolve, reject js.Value)) js.Value { + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + go handler(resolve, reject) + + return nil + })) +} + +// createClientObject wraps the NetBird client in a JavaScript object +func createClientObject(client *netbird.Client) js.Value { + obj := make(map[string]interface{}) + + obj["start"] = createStartMethod(client) + obj["stop"] = createStopMethod(client) + obj["createSSHConnection"] = createSSHMethod(client) + obj["proxyRequest"] = createProxyRequestMethod(client) + obj["createRDPProxy"] = createRDPProxyMethod(client) + + return js.ValueOf(obj) +} + +// netBirdClientConstructor acts as a JavaScript constructor function +func netBirdClientConstructor(this js.Value, args []js.Value) any { + return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + if len(args) < 1 { + reject.Invoke(js.ValueOf("Options object required")) + return nil + } + + go func() { + options, err := parseClientOptions(args[0]) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil { + log.Warnf("Failed to initialize logging: %v", err) + } + + log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s", + options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL) + + client, err := netbird.New(options) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err))) + return + } + + clientObj := createClientObject(client) + log.Info("NetBird client created successfully") + resolve.Invoke(clientObj) + }() + + return nil + })) +} diff --git a/client/wasm/internal/http/http.go b/client/wasm/internal/http/http.go new file mode 100644 index 000000000..cddc9e681 --- /dev/null +++ b/client/wasm/internal/http/http.go @@ -0,0 +1,100 @@ +//go:build js + +package http + +import ( + "fmt" + "io" + log "github.com/sirupsen/logrus" + "net/http" + "strings" + "syscall/js" + "time" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + httpTimeout = 30 * time.Second + maxResponseSize = 1024 * 1024 // 1MB +) + +// performRequest executes an HTTP request through NetBird and returns the response and body +func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) { + httpClient := nbClient.NewHTTPClient() + httpClient.Timeout = httpTimeout + + req, err := http.NewRequest(method, url, strings.NewReader(string(body))) + if err != nil { + return nil, nil, fmt.Errorf("create request: %w", err) + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Errorf("failed to close response body: %v", err) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, nil, fmt.Errorf("read response: %w", err) + } + + return resp, respBody, nil +} + +// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object +func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) { + url := request.Get("url").String() + if url == "" { + return js.Undefined(), fmt.Errorf("URL is required") + } + + method := "GET" + if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() { + method = strings.ToUpper(methodVal.String()) + } + + var requestBody []byte + if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() { + requestBody = []byte(bodyVal.String()) + } + + requestHeaders := make(map[string]string) + if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject { + headerKeys := js.Global().Get("Object").Call("keys", headersVal) + for i := 0; i < headerKeys.Length(); i++ { + key := headerKeys.Index(i).String() + value := headersVal.Get(key).String() + requestHeaders[key] = value + } + } + + resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody) + if err != nil { + return js.Undefined(), err + } + + result := js.Global().Get("Object").New() + result.Set("status", resp.StatusCode) + result.Set("statusText", resp.Status) + result.Set("body", string(body)) + + headers := js.Global().Get("Object").New() + for key, values := range resp.Header { + if len(values) > 0 { + headers.Set(strings.ToLower(key), values[0]) + } + } + result.Set("headers", headers) + + return result, nil +} diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go new file mode 100644 index 000000000..4a23a4bc8 --- /dev/null +++ b/client/wasm/internal/rdp/cert_validation.go @@ -0,0 +1,96 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + certValidationTimeout = 60 * time.Second +) + +func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { + if !conn.wsHandlers.Get("onCertificateRequest").Truthy() { + return false, fmt.Errorf("certificate validation handler not configured") + } + + certInfo := js.Global().Get("Object").New() + certInfo.Set("ServerAddr", conn.destination) + + certArray := js.Global().Get("Array").New() + for i, certBytes := range certChain { + uint8Array := js.Global().Get("Uint8Array").New(len(certBytes)) + js.CopyBytesToJS(uint8Array, certBytes) + certArray.SetIndex(i, uint8Array) + } + certInfo.Set("ServerCertChain", certArray) + if len(certChain) > 0 { + cert, err := x509.ParseCertificate(certChain[0]) + if err == nil { + info := js.Global().Get("Object").New() + info.Set("subject", cert.Subject.String()) + info.Set("issuer", cert.Issuer.String()) + info.Set("validFrom", cert.NotBefore.Format(time.RFC3339)) + info.Set("validTo", cert.NotAfter.Format(time.RFC3339)) + info.Set("serialNumber", cert.SerialNumber.String()) + certInfo.Set("CertificateInfo", info) + } + } + + promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) + + resultChan := make(chan bool) + errorChan := make(chan error) + + promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result := args[0].Bool() + resultChan <- result + return nil + })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + errorChan <- fmt.Errorf("certificate validation failed") + return nil + })) + + select { + case result := <-resultChan: + if result { + log.Info("Certificate accepted by user") + } else { + log.Info("Certificate rejected by user") + } + return result, nil + case err := <-errorChan: + return false, err + case <-time.After(certValidationTimeout): + return false, fmt.Errorf("certificate validation timeout") + } +} + +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // We'll validate manually after handshake + VerifyConnection: func(cs tls.ConnectionState) error { + var certChain [][]byte + for _, cert := range cs.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + + accepted, err := p.validateCertificateWithJS(conn, certChain) + if err != nil { + return err + } + if !accepted { + return fmt.Errorf("certificate rejected by user") + } + + return nil + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go new file mode 100644 index 000000000..8062a05cc --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -0,0 +1,271 @@ +//go:build js + +package rdp + +import ( + "context" + "crypto/tls" + "encoding/asn1" + "fmt" + "io" + "net" + "sync" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +const ( + RDCleanPathVersion = 3390 + RDCleanPathProxyHost = "rdcleanpath.proxy.local" + RDCleanPathProxyScheme = "ws" +) + +type RDCleanPathPDU struct { + Version int64 `asn1:"tag:0,explicit"` + Error []byte `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathProxy struct { + nbClient interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) + } + activeConnections map[string]*proxyConnection + destinations map[string]string + mu sync.Mutex +} + +type proxyConnection struct { + id string + destination string + rdpConn net.Conn + tlsConn *tls.Conn + wsHandlers js.Value + ctx context.Context + cancel context.CancelFunc +} + +// NewRDCleanPathProxy creates a new RDCleanPath proxy +func NewRDCleanPathProxy(client interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) +}) *RDCleanPathProxy { + return &RDCleanPathProxy{ + nbClient: client, + activeConnections: make(map[string]*proxyConnection), + } +} + +// CreateProxy creates a new proxy endpoint for the given destination +func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { + destination := fmt.Sprintf("%s:%s", hostname, port) + + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { + resolve := args[0] + + go func() { + proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + + p.mu.Lock() + if p.destinations == nil { + p.destinations = make(map[string]string) + } + p.destinations[proxyID] = destination + p.mu.Unlock() + + proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) + + // Register the WebSocket handler for this specific proxy + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: requires WebSocket argument") + } + + ws := args[0] + p.HandleWebSocketConnection(ws, proxyID) + return nil + })) + + log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) + resolve.Invoke(proxyURL) + }() + + return nil + })) +} + +// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP +func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) { + p.mu.Lock() + destination := p.destinations[proxyID] + p.mu.Unlock() + + if destination == "" { + log.Errorf("No destination found for proxy ID: %s", proxyID) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + // Don't defer cancel here - it will be called by cleanupConnection + + conn := &proxyConnection{ + id: proxyID, + destination: destination, + wsHandlers: ws, + ctx: ctx, + cancel: cancel, + } + + p.mu.Lock() + p.activeConnections[proxyID] = conn + p.mu.Unlock() + + p.setupWebSocketHandlers(ws, conn) + + log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID) +} + +func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { + ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return nil + } + + data := args[0] + go p.handleWebSocketMessage(conn, data) + return nil + })) + + ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + log.Debug("WebSocket closed by JavaScript") + conn.cancel() + return nil + })) +} + +func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { + if !data.InstanceOf(js.Global().Get("Uint8Array")) { + return + } + + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + + if conn.rdpConn != nil || conn.tlsConn != nil { + p.forwardToRDP(conn, bytes) + return + } + + var pdu RDCleanPathPDU + _, err := asn1.Unmarshal(bytes, &pdu) + if err != nil { + log.Warnf("Failed to parse RDCleanPath PDU: %v", err) + n := len(bytes) + if n > 20 { + n = 20 + } + log.Warnf("First %d bytes: %x", n, bytes[:n]) + + if len(bytes) > 0 && bytes[0] == 0x03 { + log.Debug("Received raw RDP packet instead of RDCleanPath PDU") + go p.handleDirectRDP(conn, bytes) + return + } + return + } + + go p.processRDCleanPathPDU(conn, pdu) +} + +func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) { + var writer io.Writer + var connType string + + if conn.tlsConn != nil { + writer = conn.tlsConn + connType = "TLS" + } else if conn.rdpConn != nil { + writer = conn.rdpConn + connType = "TCP" + } else { + log.Error("No RDP connection available") + return + } + + if _, err := writer.Write(bytes); err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + } +} + +func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) { + defer p.cleanupConnection(conn) + + destination := conn.destination + log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + return + } + conn.rdpConn = rdpConn + + _, err = rdpConn.Write(firstPacket) + if err != nil { + log.Errorf("Failed to write first packet: %v", err) + return + } + + response := make([]byte, 1024) + n, err := rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + return + } + + p.sendToWebSocket(conn, response[:n]) + + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") +} + +func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil + } + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil + } + p.mu.Lock() + delete(p.activeConnections, conn.id) + p.mu.Unlock() +} + +func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { + if conn.wsHandlers.Get("receiveFromGo").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer")) + } else if conn.wsHandlers.Get("send").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("send", uint8Array.Get("buffer")) + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go new file mode 100644 index 000000000..010efa5ea --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -0,0 +1,251 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "encoding/asn1" + "io" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) + + if pdu.Version != RDCleanPathVersion { + p.sendRDCleanPathError(conn, "Unsupported version") + return + } + + destination := conn.destination + if pdu.Destination != "" { + destination = pdu.Destination + } + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, "Connection failed") + p.cleanupConnection(conn) + return + } + conn.rdpConn = rdpConn + + // RDP always starts with X.224 negotiation, then determines if TLS is needed + // Modern RDP (since Windows Vista/2008) typically requires TLS + // The X.224 Connection Confirm response will indicate if TLS is required + // For now, we'll attempt TLS for all connections as it's the modern default + p.setupTLSConnection(conn, pdu) +} + +func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + var x224Response []byte + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + x224Response = response[:n] + log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) + } + + tlsConfig := p.getTLSConfigWithValidation(conn) + + tlsConn := tls.Client(conn.rdpConn, tlsConfig) + conn.tlsConn = tlsConn + + if err := tlsConn.Handshake(); err != nil { + log.Errorf("TLS handshake failed: %v", err) + p.sendRDCleanPathError(conn, "TLS handshake failed") + return + } + + log.Info("TLS handshake successful") + + // Certificate validation happens during handshake via VerifyConnection callback + var certChain [][]byte + connState := tlsConn.ConnectionState() + if len(connState.PeerCertificates) > 0 { + for _, cert := range connState.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + log.Debugf("Extracted %d certificates from TLS connection", len(certChain)) + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + ServerCertChain: certChain, + } + + if len(x224Response) > 0 { + responsePDU.X224ConnectionPDU = x224Response + } + + p.sendRDCleanPathPDU(conn, responsePDU) + + log.Debug("Starting TLS forwarding") + go p.forwardConnToWS(conn, conn.tlsConn, "TLS") + go p.forwardWSToConn(conn, conn.tlsConn, "TLS") + + <-conn.ctx.Done() + log.Debug("TLS connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + X224ConnectionPDU: response[:n], + ServerAddr: conn.destination, + } + + p.sendRDCleanPathPDU(conn, responsePDU) + } else { + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + } + p.sendRDCleanPathPDU(conn, responsePDU) + } + + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + + <-conn.ctx.Done() + log.Debug("TCP connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal RDCleanPath PDU: %v", err) + return + } + + log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data)) + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { + pdu := RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: []byte(errorMsg), + } + + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { + msgChan := make(chan []byte) + errChan := make(chan error) + + handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + if len(args) < 1 { + errChan <- io.EOF + return nil + } + + data := args[0] + if data.InstanceOf(js.Global().Get("Uint8Array")) { + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + msgChan <- bytes + } + return nil + }) + defer handler.Release() + + conn.wsHandlers.Set("onceGoMessage", handler) + + select { + case msg := <-msgChan: + return msg, nil + case err := <-errChan: + return nil, err + case <-conn.ctx.Done(): + return nil, conn.ctx.Err() + } +} + +func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) { + for { + if conn.ctx.Err() != nil { + return + } + + msg, err := p.readWebSocketMessage(conn) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from WebSocket: %v", err) + } + return + } + + _, err = dst.Write(msg) + if err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + return + } + } +} + +func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) { + buffer := make([]byte, 32*1024) + + for { + if conn.ctx.Err() != nil { + return + } + + n, err := src.Read(buffer) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from %s: %v", connType, err) + } + return + } + + if n > 0 { + p.sendToWebSocket(conn, buffer[:n]) + } + } +} diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go new file mode 100644 index 000000000..ca35525eb --- /dev/null +++ b/client/wasm/internal/ssh/client.go @@ -0,0 +1,213 @@ +//go:build js + +package ssh + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + sshDialTimeout = 30 * time.Second +) + +func closeWithLog(c io.Closer, resource string) { + if c != nil { + if err := c.Close(); err != nil { + logrus.Debugf("Failed to close %s: %v", resource, err) + } + } +} + +type Client struct { + nbClient *netbird.Client + sshClient *ssh.Client + session *ssh.Session + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + mu sync.RWMutex +} + +// NewClient creates a new SSH client +func NewClient(nbClient *netbird.Client) *Client { + return &Client{ + nbClient: nbClient, + } +} + +// Connect establishes an SSH connection through NetBird network +func (c *Client) Connect(host string, port int, username string) error { + addr := fmt.Sprintf("%s:%d", host, port) + logrus.Infof("SSH: Connecting to %s as %s", addr, username) + + var authMethods []ssh.AuthMethod + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return fmt.Errorf("get NetBird config: %w", err) + } + if nbConfig.SSHKey == "" { + return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") + } + + signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + pubKey := signer.PublicKey() + logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) + + authMethods = append(authMethods, ssh.PublicKeys(signer)) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: sshDialTimeout, + } + + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) + defer cancel() + + conn, err := c.nbClient.Dial(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("dial %s: %w", addr, err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + closeWithLog(conn, "connection after handshake error") + return fmt.Errorf("SSH handshake: %w", err) + } + + c.sshClient = ssh.NewClient(sshConn, chans, reqs) + logrus.Infof("SSH: Connected to %s", addr) + + return nil +} + +// StartSession starts an SSH session with PTY +func (c *Client) StartSession(cols, rows int) error { + if c.sshClient == nil { + return fmt.Errorf("SSH client not connected") + } + + session, err := c.sshClient.NewSession() + if err != nil { + return fmt.Errorf("create session: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + c.session = session + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.VINTR: 3, + ssh.VQUIT: 28, + ssh.VERASE: 127, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + closeWithLog(session, "session after PTY error") + return fmt.Errorf("PTY request: %w", err) + } + + c.stdin, err = session.StdinPipe() + if err != nil { + closeWithLog(session, "session after stdin error") + return fmt.Errorf("get stdin: %w", err) + } + + c.stdout, err = session.StdoutPipe() + if err != nil { + closeWithLog(session, "session after stdout error") + return fmt.Errorf("get stdout: %w", err) + } + + c.stderr, err = session.StderrPipe() + if err != nil { + closeWithLog(session, "session after stderr error") + return fmt.Errorf("get stderr: %w", err) + } + + if err := session.Shell(); err != nil { + closeWithLog(session, "session after shell error") + return fmt.Errorf("start shell: %w", err) + } + + logrus.Info("SSH: Session started with PTY") + return nil +} + +// Write sends data to the SSH session +func (c *Client) Write(data []byte) (int, error) { + c.mu.RLock() + stdin := c.stdin + c.mu.RUnlock() + + if stdin == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdin.Write(data) +} + +// Read reads data from the SSH session +func (c *Client) Read(buffer []byte) (int, error) { + c.mu.RLock() + stdout := c.stdout + c.mu.RUnlock() + + if stdout == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdout.Read(buffer) +} + +// Resize updates the terminal size +func (c *Client) Resize(cols, rows int) error { + c.mu.RLock() + session := c.session + c.mu.RUnlock() + + if session == nil { + return fmt.Errorf("SSH session not started") + } + return session.WindowChange(rows, cols) +} + +// Close closes the SSH connection +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.session != nil { + closeWithLog(c.session, "SSH session") + c.session = nil + } + if c.stdin != nil { + closeWithLog(c.stdin, "stdin") + c.stdin = nil + } + c.stdout = nil + c.stderr = nil + + if c.sshClient != nil { + err := c.sshClient.Close() + c.sshClient = nil + return err + } + return nil +} diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go new file mode 100644 index 000000000..ea64eb0aa --- /dev/null +++ b/client/wasm/internal/ssh/handlers.go @@ -0,0 +1,78 @@ +//go:build js + +package ssh + +import ( + "io" + "syscall/js" + + "github.com/sirupsen/logrus" +) + +// CreateJSInterface creates a JavaScript interface for the SSH client +func CreateJSInterface(client *Client) js.Value { + jsInterface := js.Global().Get("Object").Call("create", js.Null()) + + jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf(false) + } + + data := args[0] + var bytes []byte + + if data.Type() == js.TypeString { + bytes = []byte(data.String()) + } else { + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes = make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + } + + _, err := client.Write(bytes) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf(false) + } + cols := args[0].Int() + rows := args[1].Int() + err := client.Resize(cols, rows) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + client.Close() + return js.Undefined() + })) + + go readLoop(client, jsInterface) + + return jsInterface +} + +func readLoop(client *Client, jsInterface js.Value) { + buffer := make([]byte, 4096) + for { + n, err := client.Read(buffer) + if err != nil { + if err != io.EOF { + logrus.Debugf("SSH read error: %v", err) + } + if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() { + onclose.Invoke() + } + client.Close() + return + } + + if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() { + uint8Array := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(uint8Array, buffer[:n]) + ondata.Invoke(uint8Array) + } + } +} diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go new file mode 100644 index 000000000..4868ba30a --- /dev/null +++ b/client/wasm/internal/ssh/key.go @@ -0,0 +1,50 @@ +//go:build js + +package ssh + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format +func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { + keyStr := string(keyPEM) + if !strings.Contains(keyStr, "-----BEGIN") { + keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") + } + + signer, err := ssh.ParsePrivateKey(keyPEM) + if err == nil { + return signer, nil + } + logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) + + block, _ := pem.Decode(keyPEM) + if block == nil { + keyPreview := string(keyPEM) + if len(keyPreview) > 100 { + keyPreview = keyPreview[:100] + } + return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) + if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(rsaKey) + } + if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(ecKey) + } + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.NewSignerFromKey(key) +} diff --git a/encryption/route53.go b/encryption/route53.go index 3c81ab103..48c7a3a1b 100644 --- a/encryption/route53.go +++ b/encryption/route53.go @@ -1,3 +1,5 @@ +//go:build !js + package encryption import ( diff --git a/flow/client/client.go b/flow/client/client.go index 603fd6882..03a4accaf 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return nil, fmt.Errorf("parsing url: %w", err) } var opts []grpc.DialOption - if parsedURL.Scheme == "https" { + tlsEnabled := parsedURL.Scheme == "https" + if tlsEnabled { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(), + nbgrpc.WithCustomDialer(tlsEnabled), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/go.mod b/go.mod index 23aa45277..c4b629993 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/lib/v4 v4.2.0 diff --git a/go.sum b/go.sum index 7096be3fe..13838b82d 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 984a56a39..ddd81daa2 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { @@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager { }) } -func (s *BaseServer) EphemeralManager() *server.EphemeralManager { - return Create(s, func() *server.EphemeralManager { - return server.NewEphemeralManager(s.Store(), s.AccountManager()) +func (s *BaseServer) EphemeralManager() ephemeral.Manager { + return Create(s, func() ephemeral.Manager { + return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 70f0f93a9..daec4ef6f 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager { if err != nil { log.Fatalf("failed to create account manager: %v", err) } + + s.AfterInit(func(s *BaseServer) { + accountManager.SetEphemeralManager(s.EphemeralManager()) + }) return accountManager }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index e868c2529..ae9ac4a60 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,12 +6,14 @@ import ( "fmt" "net" "net/http" + "net/netip" "strings" "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -22,6 +24,8 @@ import ( "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" ) @@ -92,12 +96,6 @@ func (s *BaseServer) Start(ctx context.Context) error { s.PeersManager() s.GeoLocationManager() - for _, fn := range s.afterInit { - if fn != nil { - fn(s) - } - } - err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") if err != nil { return fmt.Errorf("failed to expose metrics: %v", err) @@ -147,7 +145,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler()) + rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -176,6 +174,12 @@ func (s *BaseServer) Start(ctx context.Context) error { } } + for _, fn := range s.afterInit { + if fn != nil { + fn(s) + } + } + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) @@ -247,13 +251,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) return util.DirectWriteJson(ctx, path, config) } -func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler { +func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto") - if request.ProtoMajor == 2 && grpcHeader { + switch { + case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || + strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - } else { + case request.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(writer, request) + default: httpHandler.ServeHTTP(writer, request) } }) diff --git a/management/server/account.go b/management/server/account.go index ee9f294a4..dca105ddf 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -74,6 +75,7 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation + ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -261,6 +263,10 @@ func BuildManager( return am, nil } +func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { + am.ephemeralManager = em +} + func (am *DefaultAccountManager) startWarmup(ctx context.Context) { var initialInterval int64 intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 30fbbbc3e..a1ed9498b 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -12,6 +12,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -56,7 +57,7 @@ type Manager interface { UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) @@ -125,5 +126,6 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 81a921bf9..07d2f2383 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account setupKey = key.Key } - _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) return @@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, @@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Status: &nbpeer.PeerStatus{ Connected: true, LastSeen: time.Now().UTC(), }, - }) + }, false) if err != nil { t.Fatalf("expecting peer to be added, got failure %v", err) } @@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { key2, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) require.NoError(t, err, "unable to add peer1") - peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) require.NoError(t, err, "unable to add peer2") t.Run("update peer IP successfully", func(t *testing.T) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 55a1bbe66..a2a2ce529 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -281,11 +281,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 60a00207e..1177eefff 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -22,6 +22,7 @@ import ( integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,7 +56,7 @@ type GRPCServer struct { config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager + ephemeralManager ephemeral.Manager peerLocks sync.Map authManager auth.Manager @@ -73,7 +74,7 @@ func NewServer( peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager *EphemeralManager, + ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, ) (*GRPCServer, error) { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index af501e151..4b33495de 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") } // NewHandler creates a new peers Handler @@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + var req api.PeerTemporaryAccessRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + newPeer := &nbpeer.Peer{} + newPeer.FromAPITemporaryAccessRequest(&req) + + targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + for _, rule := range req.Rules { + protocol, portRange, err := types.ParseRuleString(rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + policy := &types.Policy{ + AccountID: userAuth.AccountId, + Description: "Temporary access policy for peer " + peer.Name, + Name: "Temporary access policy for peer " + peer.Name, + Enabled: true, + Rules: []*types.PolicyRule{{ + Name: "Temporary access rule", + Description: "Temporary access rule", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + SourceResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: peer.ID, + }, + DestinationResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: targetPeer.ID, + }, + Bidirectional: false, + Protocol: protocol, + PortRanges: []types.RulePortRange{portRange}, + }}, + } + + _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + } + + resp := &api.PeerTemporaryAccessResponse{ + Id: peer.ID, + Name: peer.Name, + Rules: req.Rules, + } + + util.WriteJSONObject(r.Context(), w, resp) +} + func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ba4997d22..a34d2086b 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - ephemeralMgr := NewEphemeralManager(store, accountManager) + ephemeralMgr := manager.NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) if err != nil { return nil, nil, "", cleanup, err diff --git a/management/server/management_test.go b/management/server/management_test.go index 61dc46d87..1a5e47354 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -228,7 +229,7 @@ func startServer( peersUpdateManager, secretsManager, nil, - nil, + &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, ) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 003385eb5..d160e7269 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,6 +15,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -41,7 +42,7 @@ type MockAccountManager struct { DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) @@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( ctx context.Context, + accountID string, setupKey string, userId string, peer *nbpeer.Peer, + temporary bool, ) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(ctx, setupKey, userId, peer) + return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } @@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } +// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface +func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { + // Mock implementation - does nothing +} + func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 959e7856a..6c985410c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 294f51676..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -132,7 +132,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc res := nbtypes.Resource{ ID: resource.ID, - Type: resource.Type.String(), + Type: nbtypes.ResourceType(resource.Type.String()), } for _, groupID := range resource.GroupIDs { event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) @@ -265,7 +265,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { res := nbtypes.Resource{ ID: newResource.ID, - Type: newResource.Type.String(), + Type: nbtypes.ResourceType(newResource.Type.String()), } oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 81f037499..ea4617af0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -450,7 +450,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -482,8 +482,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var ephemeral bool var groupsToAdd []string var allowExtraDNSLabels bool - var accountID string - var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -492,10 +490,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if user.PendingApproval { return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } - groupsToAdd = user.AutoGroups + if temporary { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create) + if err != nil { + return nil, nil, nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, nil, nil, status.NewPermissionDeniedError() + } + } else { + accountID = user.AccountID + groupsToAdd = user.AutoGroups + } opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser - accountID = user.AccountID } else { // Validate the setup key sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) @@ -516,13 +525,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } } opEvent.AccountID = accountID + if temporary { + ephemeral = true + } + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) @@ -549,10 +561,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s SSHKey: peer.SSHKey, LastLogin: ®istrationTime, CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, + LoginExpirationEnabled: addedByUser && !temporary, Ephemeral: ephemeral, Location: peer.Location, - InactivityExpirationEnabled: addedByUser, + InactivityExpirationEnabled: addedByUser && !temporary, ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } @@ -588,7 +600,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - if isEphemeral || attempt > 1 { + if ephemeral || attempt > 1 { freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) @@ -622,6 +634,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed adding peer to All group: %w", err) } + if temporary { + // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually + am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -790,7 +807,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo ExtraDNSLabels: login.ExtraDNSLabels, } - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, "", login.SetupKey, login.UserID, newPeer, false) } log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) @@ -877,6 +894,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer if peer.SSHKey != login.SSHKey { peer.SSHKey = login.SSHKey shouldStorePeer = true + updateRemotePeers = true } if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 { @@ -1540,6 +1558,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID) + if err != nil { + return nil, err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 6a6d1c91d..f89f10dac 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -8,6 +8,7 @@ import ( "time" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Peer represents a machine connected to the network. @@ -334,6 +335,17 @@ func (p *Peer) UpdateLastLogin() *Peer { return p } +func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) { + p.Ephemeral = true + p.Name = a.Name + p.Key = a.WgPubKey + p.Meta = PeerSystemMeta{ + Hostname: a.Name, + GoOS: "js", + OS: "js", + } +} + func (f Flags) isEqual(other Flags) bool { return f.RosenpassEnabled == other.RosenpassEnabled && f.RosenpassPermissive == other.RosenpassPermissive && diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31c309430..734536d7b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -193,10 +193,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -207,10 +207,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -266,10 +266,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -280,10 +280,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -442,10 +442,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -456,10 +456,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -514,10 +514,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -530,10 +530,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Fatal(err) return @@ -702,19 +702,19 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return } - _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1300,7 +1300,7 @@ func Test_RegisterPeerByUser(t *testing.T) { }, } - addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false) require.NoError(t, err) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) @@ -1422,7 +1422,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels, } - addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false) if tc.expectAddPeerError { require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID) @@ -1523,7 +1523,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { SSHEnabled: false, } - _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false) require.Error(t, err) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key) @@ -1658,7 +1658,7 @@ func Test_LoginPeer(t *testing.T) { if sk.AllowExtraDNSLabels { currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels } - _, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer) + _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false) require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey) loginInput := types.PeerLogin{ @@ -1797,10 +1797,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1918,11 +1918,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1982,11 +1982,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2037,11 +2037,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2208,7 +2208,7 @@ func Test_AddPeer(t *testing.T) { <-start - _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false) if err != nil { errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return @@ -2416,7 +2416,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false) require.Error(t, err) assert.Contains(t, err.Error(), "user pending approval cannot add peers") } @@ -2451,7 +2451,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false) require.NoError(t, err, "Regular user should be able to add peers") } @@ -2494,7 +2494,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false) require.NoError(t, err) // Now set the user back to pending approval after peer was created @@ -2550,7 +2550,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false) require.NoError(t, err) // Try to login with regular user diff --git a/management/server/peers/ephemeral/interface.go b/management/server/peers/ephemeral/interface.go new file mode 100644 index 000000000..a1605b3b9 --- /dev/null +++ b/management/server/peers/ephemeral/interface.go @@ -0,0 +1,14 @@ +package ephemeral + +import ( + "context" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type Manager interface { + LoadInitialPeers(ctx context.Context) + Stop() + OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) + OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) +} diff --git a/management/server/ephemeral.go b/management/server/peers/ephemeral/manager/ephemeral.go similarity index 99% rename from management/server/ephemeral.go rename to management/server/peers/ephemeral/manager/ephemeral.go index e3cb5459a..062ba69d2 100644 --- a/management/server/ephemeral.go +++ b/management/server/peers/ephemeral/manager/ephemeral.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" diff --git a/management/server/ephemeral_test.go b/management/server/peers/ephemeral/manager/ephemeral_test.go similarity index 75% rename from management/server/ephemeral_test.go rename to management/server/peers/ephemeral/manager/ephemeral_test.go index d07b9a422..fc7525c29 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/peers/ephemeral/manager/ephemeral_test.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" @@ -7,12 +7,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" ) type MockStore struct { @@ -223,3 +226,57 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) store.account.Peers[p.ID] = p } } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, + } + + if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} diff --git a/management/server/policy.go b/management/server/policy.go index 3adee6397..9e4b3f73a 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -151,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return false, nil } + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err @@ -161,6 +167,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a } } + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 027938320..382d026c8 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2037,6 +2037,25 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) }) } +func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policyRules []*types.PolicyRule + resourceIDPattern := `%"ID":"` + resourceID + `"%` + result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern). + Find(&policyRules) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store") + } + + return policyRules, nil +} + // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 3c9d896b0..21b660d96 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -202,6 +202,7 @@ type Store interface { IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index a69d3bb08..f830023c7 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1001,8 +1001,20 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID) + } else { + sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID) + } else { + destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + } if rule.Bidirectional { if peerInSources { @@ -1124,6 +1136,15 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe return filteredPeers, peerInGroups } +func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + peer := a.GetPeer(resource.ID) + if peer == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peer}, resource.ID == peerID +} + // validatePostureChecksOnPeer validates the posture checks on a peer func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] @@ -1379,7 +1400,12 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } if addSourcePeers { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 17964ed1f..5e86a87c6 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,12 @@ package types +import ( + "errors" + "fmt" + "strconv" + "strings" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -134,3 +141,83 @@ func (p *Policy) SourceGroups() []string { return groupIDs } + +func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) { + rule = strings.TrimSpace(strings.ToLower(rule)) + if rule == "all" { + return PolicyRuleProtocolALL, RulePortRange{}, nil + } + if rule == "icmp" { + return PolicyRuleProtocolICMP, RulePortRange{}, nil + } + + split := strings.Split(rule, "/") + if len(split) != 2 { + return "", RulePortRange{}, errors.New("invalid rule format: expected protocol/port or protocol/port-range") + } + + protoStr := strings.TrimSpace(split[0]) + portStr := strings.TrimSpace(split[1]) + + var protocol PolicyRuleProtocolType + switch protoStr { + case "tcp": + protocol = PolicyRuleProtocolTCP + case "udp": + protocol = PolicyRuleProtocolUDP + case "icmp": + return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + default: + return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) + } + + portRange, err := parsePortRange(portStr) + if err != nil { + return "", RulePortRange{}, err + } + + return protocol, portRange, nil +} + +func parsePortRange(portStr string) (RulePortRange, error) { + if strings.Contains(portStr, "-") { + rangeParts := strings.Split(portStr, "-") + if len(rangeParts) != 2 { + return RulePortRange{}, fmt.Errorf("invalid port range %q", portStr) + } + start, err := parsePort(strings.TrimSpace(rangeParts[0])) + if err != nil { + return RulePortRange{}, err + } + end, err := parsePort(strings.TrimSpace(rangeParts[1])) + if err != nil { + return RulePortRange{}, err + } + if start > end { + return RulePortRange{}, fmt.Errorf("invalid port range: start %d > end %d", start, end) + } + return RulePortRange{Start: uint16(start), End: uint16(end)}, nil + } + + p, err := parsePort(portStr) + if err != nil { + return RulePortRange{}, err + } + + return RulePortRange{Start: uint16(p), End: uint16(p)}, nil +} + +func parsePort(portStr string) (int, error) { + + if portStr == "" { + return 0, errors.New("empty port") + } + p, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port %q: %w", portStr, err) + } + if p < 1 || p > 65535 { + return 0, fmt.Errorf("port out of range (1–65535): %d", p) + } + return p, nil +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go index 84d8e4b88..8347d8c03 100644 --- a/management/server/types/resource.go +++ b/management/server/types/resource.go @@ -4,9 +4,18 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +type ResourceType string + +const ( + ResourceTypePeer ResourceType = "peer" + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + type Resource struct { ID string - Type string + Type ResourceType } func (r *Resource) ToAPIResponse() *api.Resource { @@ -26,5 +35,5 @@ func (r *Resource) FromAPIRequest(req *api.Resource) { } r.ID = req.Id - r.Type = string(req.Type) + r.Type = ResourceType(req.Type) } diff --git a/management/server/user_test.go b/management/server/user_test.go index 9638559f9..5920a2a33 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1439,10 +1439,10 @@ func TestUserAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) // updating user with linked peers should update account peers and send peer update diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index becc10ded..d4a9f1823 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" @@ -27,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +119,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 9a531b2ff..93578b1ae 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -507,6 +507,48 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerTemporaryAccessRequest: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + wg_pub_key: + description: Peer's WireGuard public key + type: string + example: "n0r3pL4c3h0ld3rK3y==" + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - wg_pub_key + - rules + PeerTemporaryAccessResponse: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + id: + description: Peer ID + type: string + example: chacbco6lnnbn6cg5s90 + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - id + - rules AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -1404,7 +1446,8 @@ components: allOf: - $ref: '#/components/schemas/NetworkResourceType' - type: string - example: host + enum: ["peer"] + example: peer NetworkRequest: type: object properties: @@ -2793,6 +2836,42 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/peers/{peerId}/temporary-access: + post: + summary: Create a Temporary Access Peer + description: Creates a temporary access peer that can be used to access this peer and this peer only. The temporary access peer and its access policies will be automatically deleted after it disconnects. + tags: [ Peers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: peerId + required: true + schema: + type: string + description: The unique identifier of a peer + requestBody: + description: Temporary Access Peer create request + content: + 'application/json': + schema: + $ref: '#/components/schemas/PeerTemporaryAccessRequest' + responses: + '200': + description: Temporary Access Peer response + content: + application/json: + schema: + $ref: '#/components/schemas/PeerTemporaryAccessResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers/{peerId}/ingress/ports: get: x-cloud-only: true diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 28b89633c..3dbb32ef6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -168,6 +168,7 @@ const ( const ( ResourceTypeDomain ResourceType = "domain" ResourceTypeHost ResourceType = "host" + ResourceTypePeer ResourceType = "peer" ResourceTypeSubnet ResourceType = "subnet" ) @@ -1221,6 +1222,30 @@ type PeerRequest struct { SshEnabled bool `json:"ssh_enabled"` } +// PeerTemporaryAccessRequest defines model for PeerTemporaryAccessRequest. +type PeerTemporaryAccessRequest struct { + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` + + // WgPubKey Peer's WireGuard public key + WgPubKey string `json:"wg_pub_key"` +} + +// PeerTemporaryAccessResponse defines model for PeerTemporaryAccessResponse. +type PeerTemporaryAccessResponse struct { + // Id Peer ID + Id string `json:"id"` + + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` +} + // PersonalAccessToken defines model for PersonalAccessToken. type PersonalAccessToken struct { // CreatedAt Date the token was created @@ -1949,6 +1974,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques // PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType. type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest +// PostApiPeersPeerIdTemporaryAccessJSONRequestBody defines body for PostApiPeersPeerIdTemporaryAccess for application/json ContentType. +type PostApiPeersPeerIdTemporaryAccessJSONRequestBody = PeerTemporaryAccessRequest + // PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType. type PostApiPoliciesJSONRequestBody = PolicyUpdate diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index bf614e8aa..8381d6682 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v6.32.0 // source: management.proto package proto diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 5dabc5742..57a98614d 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,11 +9,8 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" - "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" - "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" ) @@ -296,14 +293,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues - var dialers []dialer.DialeFn - if c.mtu > 0 && c.mtu > iface.DefaultMTU { - c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) - dialers = []dialer.DialeFn{ws.Dialer{}} - } else { - dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} - } + dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) conn, err := rd.Dial() diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index 0086b702b..d5b719f51 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Write(b []byte) (n int, err error) { - err = c.Conn.Write(c.ctx, websocket.MessageBinary, b) - return 0, err + return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b) } func (c *Conn) RemoteAddr() net.Addr { diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go new file mode 100644 index 000000000..9dfe698d0 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -0,0 +1,11 @@ +//go:build !js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + return &websocket.DialOptions{ + HTTPClient: httpClientNbDialer(), + } +} diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go new file mode 100644 index 000000000..7eac27531 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -0,0 +1,10 @@ +//go:build js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + // WASM version doesn't support HTTPClient + return &websocket.DialOptions{} +} diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index ef6bd6b3c..66fff3447 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - opts := &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), - } + opts := createDialOptions() parsedURL, err := url.Parse(wsURL) if err != nil { diff --git a/shared/relay/client/dialers_generic.go b/shared/relay/client/dialers_generic.go new file mode 100644 index 000000000..a8ed79961 --- /dev/null +++ b/shared/relay/client/dialers_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package client + +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +// getDialers returns the list of dialers to use for connecting to the relay server. +func (c *Client) getDialers() []dialer.DialeFn { + if c.mtu > 0 && c.mtu > iface.DefaultMTU { + c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) + return []dialer.DialeFn{ws.Dialer{}} + } + return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} +} diff --git a/shared/relay/client/dialers_js.go b/shared/relay/client/dialers_js.go new file mode 100644 index 000000000..6bd0e6696 --- /dev/null +++ b/shared/relay/client/dialers_js.go @@ -0,0 +1,13 @@ +//go:build js + +package client + +import ( + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +func (c *Client) getDialers() []dialer.DialeFn { + // JS/WASM build only uses WebSocket transport + return []dialer.DialeFn{ws.Dialer{}} +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 1d76fa4e4..e2a69a75b 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,14 +8,16 @@ import ( "fmt" "net" "net/http" - // nolint:gosec _ "net/http/pprof" - "strings" + "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "github.com/netbirdio/netbird/signal/metrics" @@ -23,6 +25,8 @@ import ( "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/server" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" log "github.com/sirupsen/logrus" @@ -32,6 +36,8 @@ import ( "google.golang.org/grpc/keepalive" ) +const legacyGRPCPort = 10000 + var ( signalPort int metricsPort int @@ -113,7 +119,7 @@ var ( } proto.RegisterSignalExchangeServer(grpcServer, srv) - grpcRootHandler := grpcHandlerFunc(grpcServer) + grpcRootHandler := grpcHandlerFunc(grpcServer, metricsServer.Meter) if certManager != nil { startServerWithCertManager(certManager, grpcRootHandler) @@ -123,19 +129,30 @@ var ( var grpcListener net.Listener var httpListener net.Listener - // If certManager is configured and signalPort == 443, then the gRPC server has already been started - if certManager == nil || signalPort != 443 { - grpcListener, err = serveGRPC(grpcServer, signalPort) + // Start the main server - always serve HTTP with WebSocket proxy support + // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager + if certManager == nil { + // Without TLS, serve plain HTTP + httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { return err } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) + } else if signalPort != 443 { + // With TLS but not on port 443, serve HTTPS + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + if err != nil { + return err + } + log.Infof("running HTTPS server with WebSocket proxy: %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) } - if signalPort != 10000 { + if signalPort != legacyGRPCPort { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. - compatListener, err = serveGRPC(grpcServer, 10000) + compatListener, err = serveGRPC(grpcServer, legacyGRPCPort) if err != nil { return err } @@ -236,11 +253,14 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } } -func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { +func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto") - if r.ProtoMajor == 2 && grpcHeader { + switch { + case r.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(w, r) + default: grpcServer.ServeHTTP(w, r) } }) @@ -257,7 +277,11 @@ func notifyStop(msg string) { func serveHTTP(httpListener net.Listener, handler http.Handler) { go func() { - err := http.Serve(httpListener, handler) + // Use h2c to support HTTP/2 without TLS (needed for gRPC) + h1s := &http.Server{ + Handler: h2c.NewHandler(handler, &http2.Server{}), + } + err := h1s.Serve(httpListener) if err != nil { notifyStop(fmt.Sprintf("failed running HTTP server %v", err)) } diff --git a/util/util_js.go b/util/util_js.go new file mode 100644 index 000000000..8c243cab3 --- /dev/null +++ b/util/util_js.go @@ -0,0 +1,8 @@ +//go:build js + +package util + +// IsAdmin returns false for WASM as there's no admin concept in browser +func IsAdmin() bool { + return false +} diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go new file mode 100644 index 000000000..2caeed025 --- /dev/null +++ b/util/wsproxy/client/dialer_js.go @@ -0,0 +1,171 @@ +package client + +import ( + "context" + "fmt" + "net" + "sync" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const dialTimeout = 30 * time.Second + +// websocketConn wraps a JavaScript WebSocket to implement net.Conn +type websocketConn struct { + ws js.Value + remoteAddr string + messages chan []byte + readBuf []byte + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +func (c *websocketConn) Read(b []byte) (int, error) { + c.mu.Lock() + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + } + c.mu.Unlock() + + select { + case data := <-c.messages: + n := copy(b, data) + if n < len(data) { + c.mu.Lock() + c.readBuf = data[n:] + c.mu.Unlock() + } + return n, nil + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } +} + +func (c *websocketConn) Write(b []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + } + + uint8Array := js.Global().Get("Uint8Array").New(len(b)) + js.CopyBytesToJS(uint8Array, b) + c.ws.Call("send", uint8Array) + return len(b), nil +} + +func (c *websocketConn) Close() error { + c.cancel() + c.ws.Call("close") + return nil +} + +func (c *websocketConn) LocalAddr() net.Addr { + return nil +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return stringAddr(c.remoteAddr) +} +func (c *websocketConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// stringAddr is a simple net.Addr that returns a string +type stringAddr string + +func (s stringAddr) Network() string { return "tcp" } +func (s stringAddr) String() string { return string(s) } + +// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. +func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + scheme := "wss" + if !tlsEnabled { + scheme = "ws" + } + wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + + ws := js.Global().Get("WebSocket").New(wsURL) + + connCtx, connCancel := context.WithCancel(context.Background()) + conn := &websocketConn{ + ws: ws, + remoteAddr: addr, + messages: make(chan []byte, 100), + ctx: connCtx, + cancel: connCancel, + } + + ws.Set("binaryType", "arraybuffer") + + openCh := make(chan struct{}) + errorCh := make(chan error, 1) + + ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any { + close(openCh) + return nil + })) + + ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any { + select { + case errorCh <- wsproxy.ErrConnectionFailed: + default: + } + return nil + })) + + ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any { + event := args[0] + data := event.Get("data") + + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + + select { + case conn.messages <- bytes: + default: + log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr) + } + return nil + })) + + ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.cancel() + return nil + })) + + select { + case <-openCh: + return conn, nil + case err := <-errorCh: + return nil, err + case <-ctx.Done(): + ws.Call("close") + return nil, ctx.Err() + case <-time.After(dialTimeout): + ws.Call("close") + return nil, wsproxy.ErrConnectionTimeout + } + }) +} diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go new file mode 100644 index 000000000..8d117c7d9 --- /dev/null +++ b/util/wsproxy/constants.go @@ -0,0 +1,13 @@ +package wsproxy + +import "errors" + +// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +const ProxyPath = "/ws-proxy" + +// Common errors +var ( + ErrConnectionTimeout = errors.New("WebSocket connection timeout") + ErrConnectionFailed = errors.New("WebSocket connection failed") + ErrBackendUnavailable = errors.New("backend unavailable") +) diff --git a/util/wsproxy/server/metrics.go b/util/wsproxy/server/metrics.go new file mode 100644 index 000000000..dd3b96dad --- /dev/null +++ b/util/wsproxy/server/metrics.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// MetricsRecorder defines the interface for recording proxy metrics +type MetricsRecorder interface { + // RecordConnection records a new connection + RecordConnection(ctx context.Context) + // RecordDisconnection records a connection closing + RecordDisconnection(ctx context.Context) + // RecordBytesTransferred records bytes transferred in a direction + RecordBytesTransferred(ctx context.Context, direction string, bytes int64) + // RecordError records an error + RecordError(ctx context.Context, errorType string) +} + +// NoOpMetricsRecorder is a no-op implementation that does nothing +type NoOpMetricsRecorder struct{} + +func (n NoOpMetricsRecorder) RecordConnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordDisconnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + // no-op +} +func (n NoOpMetricsRecorder) RecordError(ctx context.Context, errorType string) { + // no-op +} + +// Recorder implements MetricsRecorder using OpenTelemetry +type Recorder struct { + activeConnections metric.Int64UpDownCounter + bytesTransferred metric.Int64Counter + errors metric.Int64Counter +} + +// NewMetricsRecorder creates a new OpenTelemetry-based metrics recorder +func NewMetricsRecorder(meter metric.Meter) (*Recorder, error) { + activeConnections, err := meter.Int64UpDownCounter( + "wsproxy_active_connections", + metric.WithDescription("Number of active WebSocket proxy connections"), + ) + if err != nil { + return nil, err + } + + bytesTransferred, err := meter.Int64Counter( + "wsproxy_bytes_transferred_total", + metric.WithDescription("Total bytes transferred through the proxy"), + ) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter( + "wsproxy_errors_total", + metric.WithDescription("Total number of proxy errors"), + ) + if err != nil { + return nil, err + } + + return &Recorder{ + activeConnections: activeConnections, + bytesTransferred: bytesTransferred, + errors: errors, + }, nil +} + +func (o *Recorder) RecordConnection(ctx context.Context) { + o.activeConnections.Add(ctx, 1) +} + +func (o *Recorder) RecordDisconnection(ctx context.Context) { + o.activeConnections.Add(ctx, -1) +} + +func (o *Recorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + o.bytesTransferred.Add(ctx, bytes, metric.WithAttributes( + attribute.String("direction", direction), + )) +} + +func (o *Recorder) RecordError(ctx context.Context, errorType string) { + o.errors.Add(ctx, 1, metric.WithAttributes( + attribute.String("error_type", errorType), + )) +} + +// Option defines functional options for the Proxy +type Option func(*Config) + +// WithMetrics sets a custom metrics recorder +func WithMetrics(recorder MetricsRecorder) Option { + return func(c *Config) { + c.MetricsRecorder = recorder + } +} + +// WithOTelMeter creates and sets an OpenTelemetry metrics recorder +func WithOTelMeter(meter metric.Meter) Option { + return func(c *Config) { + if recorder, err := NewMetricsRecorder(meter); err == nil { + c.MetricsRecorder = recorder + } else { + log.Warnf("Failed to create OTel metrics recorder: %v", err) + } + } +} diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go new file mode 100644 index 000000000..977440a60 --- /dev/null +++ b/util/wsproxy/server/proxy.go @@ -0,0 +1,227 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "sync" + "time" + + "github.com/coder/websocket" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const ( + dialTimeout = 10 * time.Second + bufferSize = 32 * 1024 +) + +// Config contains the configuration for the WebSocket proxy. +type Config struct { + LocalGRPCAddr netip.AddrPort + Path string + MetricsRecorder MetricsRecorder +} + +// Proxy handles WebSocket to TCP proxying for gRPC connections. +type Proxy struct { + config Config + metrics MetricsRecorder +} + +// New creates a new WebSocket proxy instance with optional configuration +func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { + config := Config{ + LocalGRPCAddr: localGRPCAddr, + Path: wsproxy.ProxyPath, + MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op + } + + for _, opt := range opts { + opt(&config) + } + + return &Proxy{ + config: config, + metrics: config.MetricsRecorder, + } +} + +// Handler returns an http.Handler that proxies WebSocket connections to the local gRPC server. +func (p *Proxy) Handler() http.Handler { + return http.HandlerFunc(p.handleWebSocket) +} + +func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + p.metrics.RecordConnection(ctx) + defer p.metrics.RecordDisconnection(ctx) + + log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) + if err != nil { + p.metrics.RecordError(ctx, "websocket_accept_failed") + log.Errorf("WebSocket upgrade failed from %s: %v", r.RemoteAddr, err) + return + } + defer func() { + if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { + log.Debugf("Failed to close WebSocket: %v", err) + } + }() + + log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) + tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) + if err != nil { + p.metrics.RecordError(ctx, "tcp_dial_failed") + log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) + if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { + log.Debugf("Failed to close WebSocket after connection failure: %v", err) + } + return + } + defer func() { + if err := tcpConn.Close(); err != nil { + log.Debugf("Failed to close TCP connection: %v", err) + } + }() + + log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + + p.proxyData(ctx, wsConn, tcpConn) +} + +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { + proxyCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + wg.Add(2) + + go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + log.Tracef("Proxy data transfer completed, both goroutines terminated") + case <-proxyCtx.Done(): + log.Tracef("Proxy data transfer cancelled, forcing connection closure") + + if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { + log.Tracef("Error closing WebSocket during cancellation: %v", err) + } + if err := tcpConn.Close(); err != nil { + log.Tracef("Error closing TCP connection during cancellation: %v", err) + } + + select { + case <-done: + log.Tracef("Goroutines terminated after forced connection closure") + case <-time.After(2 * time.Second): + log.Tracef("Goroutines did not terminate within timeout after connection closure") + } + } +} + +func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + for { + msgType, data, err := wsConn.Read(ctx) + if err != nil { + switch { + case ctx.Err() != nil: + log.Debugf("wsToTCP goroutine terminating due to context cancellation") + case websocket.CloseStatus(err) == websocket.StatusNormalClosure: + log.Debugf("WebSocket closed normally") + default: + p.metrics.RecordError(ctx, "websocket_read_error") + log.Errorf("WebSocket read error: %v", err) + } + return + } + + if msgType != websocket.MessageBinary { + log.Warnf("Unexpected WebSocket message type: %v", msgType) + continue + } + + if ctx.Err() != nil { + log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + return + } + + if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP write deadline: %v", err) + } + + n, err := tcpConn.Write(data) + if err != nil { + p.metrics.RecordError(ctx, "tcp_write_error") + log.Errorf("TCP write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + } +} + +func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + buf := make([]byte, bufferSize) + for { + if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP read deadline: %v", err) + } + n, err := tcpConn.Read(buf) + + if err != nil { + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation") + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + continue + } + + if err != io.EOF { + log.Errorf("TCP read error: %v", err) + } + return + } + + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + return + } + + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Errorf("WebSocket write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + } +}