[client,signal,management] Add browser client support (#4415)

This commit is contained in:
Viktor Liu
2025-10-01 20:10:11 +02:00
committed by GitHub
parent 5e1a40c33f
commit b5daec3b51
107 changed files with 3591 additions and 284 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -0,0 +1,67 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored Normal file
View File

View File

@@ -2,6 +2,18 @@ version: 2
project_name: netbird
builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird
dir: client
binary: netbird
@@ -115,6 +127,11 @@ archives:
- builds:
- netbird
- netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms:
- maintainer: Netbird <dev@netbird.io>

8
client/cmd/debug_js.go Normal file
View File

@@ -0,0 +1,8 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -23,23 +23,29 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
config *profilemanager.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
jwtToken string
connect *internal.ConnectClient
}
// Options configures a new Client
// Options configures a new Client.
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
@@ -58,8 +64,35 @@ type Options struct {
DisableClientRoutes bool
}
// New creates a new netbird embedded client
// validateCredentials checks that exactly one credential type is provided
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err)
}
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
}, nil
}
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
}
}
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// ListenUDP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()

View File

@@ -4,15 +4,9 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os/user"
"runtime"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -20,37 +14,10 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/util/embeddedroots"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
@@ -58,6 +25,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// CreateConnection creates a gRPC client connection with the appropriate transport options
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
@@ -68,7 +36,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
}))
}
@@ -79,7 +49,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
connCtx,
addr,
transportOption,
WithCustomDialer(),
WithCustomDialer(tlsEnabled),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

View File

@@ -0,0 +1,44 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

12
client/grpc/dialer_js.go Normal file
View File

@@ -0,0 +1,12 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled)
}

View File

@@ -0,0 +1,7 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,3 +1,5 @@
//go:build !js
package bind
import (
@@ -21,11 +23,6 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
@@ -43,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
recvChan chan RecvMessage
transportNet transport.Net
filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
address wgaddr.Address
mtu uint16
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
recvChan chan recvMessage
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
recvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}),
closed: true,
mtu: mtu,
address: address,
activityRecorder: NewActivityRecorder(),
}
@@ -84,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
return ib
}
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
@@ -140,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP)
}
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
@@ -156,14 +160,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) {
select {
case <-ctx.Done():
return
case b.recvChan <- msg:
}
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()

View File

@@ -0,0 +1,6 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -0,0 +1,125 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd
//go:build linux || windows || freebsd || js || wasip1
package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows
//go:build !windows && !js
package configurer

View File

@@ -0,0 +1,23 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -1,9 +1,11 @@
package device
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -15,6 +17,12 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
}
type TunNetstackDevice struct {
name string
address wgaddr.Address
@@ -22,7 +30,7 @@ type TunNetstackDevice struct {
key string
mtu uint16
listenAddress string
iceBind *bind.ICEBind
bind Bind
device *device.Device
filteredDevice *FilteredDevice
@@ -33,7 +41,7 @@ type TunNetstackDevice struct {
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: iceBind,
bind: bind,
}
}
@@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
t.bind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
@@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
udpMux, err := t.bind.GetICEMux()
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
return nil, err
}
t.udpMux = udpMux
if udpMux != nil {
t.udpMux = udpMux
}
log.Debugf("netstack device is ready to use")
return udpMux, nil
}

View File

@@ -0,0 +1,27 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -0,0 +1,6 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
@@ -33,7 +33,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,27 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}

View File

@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil

View File

@@ -1,3 +1,5 @@
//go:build !js
package netstack
import (

View File

@@ -0,0 +1,12 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -16,15 +16,14 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type IceBind interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
Recv(ctx context.Context, msg bind.RecvMessage)
MTU() uint16
type Bind interface {
SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
type ProxyBind struct {
bind IceBind
bind Bind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
@@ -40,13 +39,15 @@ type ProxyBind struct {
isStarted bool
closeListener *listener.CloseListener
mtu uint16
}
func NewProxyBind(bind IceBind) *ProxyBind {
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
p := &ProxyBind{
bind: bind,
closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
}
return p
@@ -174,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}()
for {
buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead)
buf := make([]byte, p.mtu)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
@@ -190,11 +191,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
msg := bind.RecvMessage{
Endpoint: p.wgCurrentUsed,
Buffer: buf[:n],
}
p.bind.Recv(ctx, msg)
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
p.pausedCond.L.Unlock()
}
}

View File

@@ -3,24 +3,25 @@ package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
bind proxyBind.Bind
mtu uint16
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
bind: bind,
mtu: mtu,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind)
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
func (w *USPFactory) Free() error {

View File

@@ -74,7 +74,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}

View File

@@ -30,7 +30,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}

View File

@@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (hostManager, error) {
return &noopHostConfigurator{}, nil
}

View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
)
type ShutdownState struct{}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
return nil
}

View File

@@ -453,8 +453,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -466,14 +464,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("initialize dns server: %w", err)
}
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
iceCfg := e.createICEConfig()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx)
@@ -1347,14 +1338,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive,
},
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
ICEConfig: e.createICEConfig(),
}
serviceDependencies := peer.ServiceDependencies{

View File

@@ -0,0 +1,19 @@
//go:build !js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for non-WASM environments
func (e *Engine) createICEConfig() icemaker.Config {
return icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
}

View File

@@ -0,0 +1,18 @@
//go:build js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for WASM environment.
func (e *Engine) createICEConfig() icemaker.Config {
cfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
return cfg
}

View File

@@ -27,6 +27,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
@@ -42,10 +46,8 @@ import (
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -0,0 +1,12 @@
package networkmonitor
import (
"context"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
// No-op for WASM - network changes don't apply
return nil
}

View File

@@ -0,0 +1,48 @@
package systemops
import (
"errors"
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
var ErrRouteNotSupported = errors.New("route operations not supported on js")
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return ErrRouteNotSupported
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return ErrRouteNotSupported
}
func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return ErrRouteNotSupported
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return ErrRouteNotSupported
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error {
return nil
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error {
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !ios
//go:build !linux && !ios && !js
package systemops

View File

@@ -10,23 +10,26 @@ import (
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -314,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import "context"

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

137
client/ssh/ssh_js.go Normal file
View File

@@ -0,0 +1,137 @@
package ssh
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"strings"
"golang.org/x/crypto/ssh"
)
var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment")
// Server is a dummy SSH server interface for WASM.
type Server interface {
Start() error
Stop() error
EnableSSH(enabled bool)
AddAuthorizedKey(peer string, key string) error
RemoveAuthorizedKey(key string)
}
type dummyServer struct{}
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return &dummyServer{}, nil
}
func NewServer(addr string) Server {
return &dummyServer{}
}
func (s *dummyServer) Start() error {
return ErrSSHNotSupported
}
func (s *dummyServer) Stop() error {
return nil
}
func (s *dummyServer) EnableSSH(enabled bool) {
}
func (s *dummyServer) AddAuthorizedKey(peer string, key string) error {
return nil
}
func (s *dummyServer) RemoveAuthorizedKey(key string) {
}
type Client struct{}
func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) {
return nil, ErrSSHNotSupported
}
func (c *Client) Close() error {
return nil
}
func (c *Client) Run(command []string) error {
return ErrSSHNotSupported
}
type SessionRecorder struct{}
func NewSessionRecorder() *SessionRecorder {
return &SessionRecorder{}
}
func (r *SessionRecorder) Record(session string, data []byte) {
}
func GetUserShell() string {
return "/bin/sh"
}
func LookupUserInfo(username string) (string, string, error) {
return "", "", ErrSSHNotSupported
}
const DefaultSSHPort = 44338
const ED25519 = "ed25519"
func isRoot() bool {
return false
}
func GeneratePrivateKey(keyType string) ([]byte, error) {
if keyType != ED25519 {
return nil, errors.New("only ED25519 keys are supported in WASM")
}
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, err
}
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Bytes,
}
pemBytes := pem.EncodeToMemory(pemBlock)
return pemBytes, nil
}
func GeneratePublicKey(privateKey []byte) ([]byte, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
block, _ := pem.Decode(privateKey)
if block != nil {
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
signer, err = ssh.NewSignerFromKey(key)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey())
return []byte(strings.TrimSpace(string(pubKeyBytes))), nil
}

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

231
client/system/info_js.go Normal file
View File

@@ -0,0 +1,231 @@
package system
import (
"context"
"runtime"
"strings"
"syscall/js"
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves system information for WASM environment
func GetInfo(_ context.Context) *Info {
info := &Info{
GoOS: runtime.GOOS,
Kernel: runtime.GOARCH,
KernelVersion: runtime.GOARCH,
Platform: runtime.GOARCH,
OS: runtime.GOARCH,
Hostname: "wasm-client",
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
}
collectBrowserInfo(info)
collectLocationInfo(info)
collectSystemInfo(info)
return info
}
func collectBrowserInfo(info *Info) {
navigator := js.Global().Get("navigator")
if navigator.IsUndefined() {
return
}
collectUserAgent(info, navigator)
collectPlatform(info, navigator)
collectCPUInfo(info, navigator)
}
func collectUserAgent(info *Info, navigator js.Value) {
ua := navigator.Get("userAgent")
if ua.IsUndefined() {
return
}
userAgent := ua.String()
os, osVersion := parseOSFromUserAgent(userAgent)
if os != "" {
info.OS = os
}
if osVersion != "" {
info.OSVersion = osVersion
}
}
func collectPlatform(info *Info, navigator js.Value) {
// Try regular platform property
if plat := navigator.Get("platform"); !plat.IsUndefined() {
if platStr := plat.String(); platStr != "" {
info.Platform = platStr
}
}
// Try newer userAgentData API for more accurate platform
userAgentData := navigator.Get("userAgentData")
if userAgentData.IsUndefined() {
return
}
platformInfo := userAgentData.Get("platform")
if !platformInfo.IsUndefined() {
if platStr := platformInfo.String(); platStr != "" {
info.Platform = platStr
}
}
}
func collectCPUInfo(info *Info, navigator js.Value) {
hardwareConcurrency := navigator.Get("hardwareConcurrency")
if !hardwareConcurrency.IsUndefined() {
info.CPUs = hardwareConcurrency.Int()
}
}
func collectLocationInfo(info *Info) {
location := js.Global().Get("location")
if location.IsUndefined() {
return
}
if host := location.Get("hostname"); !host.IsUndefined() {
hostnameStr := host.String()
if hostnameStr != "" && hostnameStr != "localhost" {
info.Hostname = hostnameStr
}
}
}
func checkFileAndProcess(_ []string) ([]File, error) {
return []File{}, nil
}
func collectSystemInfo(info *Info) {
navigator := js.Global().Get("navigator")
if navigator.IsUndefined() {
return
}
if vendor := navigator.Get("vendor"); !vendor.IsUndefined() {
info.SystemManufacturer = vendor.String()
}
if product := navigator.Get("product"); !product.IsUndefined() {
info.SystemProductName = product.String()
}
if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() {
ua := userAgent.String()
info.Environment = detectEnvironmentFromUA(ua)
}
}
func parseOSFromUserAgent(userAgent string) (string, string) {
if userAgent == "" {
return "", ""
}
switch {
case strings.Contains(userAgent, "Windows NT"):
return parseWindowsVersion(userAgent)
case strings.Contains(userAgent, "Mac OS X"):
return parseMacOSVersion(userAgent)
case strings.Contains(userAgent, "FreeBSD"):
return "FreeBSD", ""
case strings.Contains(userAgent, "OpenBSD"):
return "OpenBSD", ""
case strings.Contains(userAgent, "NetBSD"):
return "NetBSD", ""
case strings.Contains(userAgent, "Linux"):
return parseLinuxVersion(userAgent)
case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"):
return parseiOSVersion(userAgent)
case strings.Contains(userAgent, "CrOS"):
return "ChromeOS", ""
default:
return "", ""
}
}
func parseWindowsVersion(userAgent string) (string, string) {
switch {
case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"):
return "Windows", "10/11"
case strings.Contains(userAgent, "Windows NT 10.0"):
return "Windows", "10"
case strings.Contains(userAgent, "Windows NT 6.3"):
return "Windows", "8.1"
case strings.Contains(userAgent, "Windows NT 6.2"):
return "Windows", "8"
case strings.Contains(userAgent, "Windows NT 6.1"):
return "Windows", "7"
default:
return "Windows", "Unknown"
}
}
func parseMacOSVersion(userAgent string) (string, string) {
idx := strings.Index(userAgent, "Mac OS X ")
if idx == -1 {
return "macOS", "Unknown"
}
versionStart := idx + len("Mac OS X ")
versionEnd := strings.Index(userAgent[versionStart:], ")")
if versionEnd <= 0 {
return "macOS", "Unknown"
}
ver := userAgent[versionStart : versionStart+versionEnd]
ver = strings.ReplaceAll(ver, "_", ".")
return "macOS", ver
}
func parseLinuxVersion(userAgent string) (string, string) {
if strings.Contains(userAgent, "Android") {
return "Android", extractAndroidVersion(userAgent)
}
if strings.Contains(userAgent, "Ubuntu") {
return "Ubuntu", ""
}
return "Linux", ""
}
func parseiOSVersion(userAgent string) (string, string) {
idx := strings.Index(userAgent, "OS ")
if idx == -1 {
return "iOS", "Unknown"
}
versionStart := idx + 3
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd <= 0 {
return "iOS", "Unknown"
}
ver := userAgent[versionStart : versionStart+versionEnd]
ver = strings.ReplaceAll(ver, "_", ".")
return "iOS", ver
}
func extractAndroidVersion(userAgent string) string {
if idx := strings.Index(userAgent, "Android "); idx != -1 {
versionStart := idx + len("Android ")
versionEnd := strings.IndexAny(userAgent[versionStart:], ";)")
if versionEnd > 0 {
return userAgent[versionStart : versionStart+versionEnd]
}
}
return "Unknown"
}
func detectEnvironmentFromUA(_ string) Environment {
return Environment{}
}

245
client/wasm/cmd/main.go Normal file
View File

@@ -0,0 +1,245 @@
//go:build js
package main
import (
"context"
"fmt"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn"
)
func main() {
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
select {}
}
func startClient(ctx context.Context, nbClient *netbird.Client) error {
log.Info("Starting NetBird client...")
if err := nbClient.Start(ctx); err != nil {
return err
}
log.Info("NetBird client started successfully")
return nil
}
// parseClientOptions extracts NetBird options from JavaScript object
func parseClientOptions(jsOptions js.Value) (netbird.Options, error) {
options := netbird.Options{
DeviceName: "dashboard-client",
LogLevel: defaultLogLevel,
}
if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() {
options.JWTToken = jwtToken.String()
}
if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() {
options.SetupKey = setupKey.String()
}
if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() {
options.PrivateKey = privateKey.String()
}
if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() {
mgmtURLStr := mgmtURL.String()
if mgmtURLStr != "" {
options.ManagementURL = mgmtURLStr
}
}
if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() {
options.LogLevel = logLevel.String()
}
if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() {
options.DeviceName = deviceName.String()
}
return options, nil
}
// createStartMethod creates the start method for the client
func createStartMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout)
defer cancel()
if err := startClient(ctx, client); err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
resolve.Invoke(js.ValueOf(true))
})
})
}
// createStopMethod creates the stop method for the client
func createStopMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := client.Stop(ctx); err != nil {
log.Errorf("Error stopping client: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Info("NetBird client stopped")
resolve.Invoke(js.ValueOf(true))
})
})
}
// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
}
return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username); err != nil {
reject.Invoke(err.Error())
return
}
if err := sshClient.StartSession(80, 24); err != nil {
if closeErr := sshClient.Close(); closeErr != nil {
log.Errorf("Error closing SSH client: %v", closeErr)
}
reject.Invoke(err.Error())
return
}
jsInterface := ssh.CreateJSInterface(sshClient)
resolve.Invoke(jsInterface)
})
})
}
// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: request details required")
}
request := args[0]
return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
if err != nil {
reject.Invoke(err.Error())
return
}
resolve.Invoke(response)
})
})
}
// createRDPProxyMethod creates the RDP proxy method
func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
go handler(resolve, reject)
return nil
}))
}
// createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{})
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
return js.ValueOf(obj)
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
if len(args) < 1 {
reject.Invoke(js.ValueOf("Options object required"))
return nil
}
go func() {
options, err := parseClientOptions(args[0])
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
log.Warnf("Failed to initialize logging: %v", err)
}
log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s",
options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL)
client, err := netbird.New(options)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err)))
return
}
clientObj := createClientObject(client)
log.Info("NetBird client created successfully")
resolve.Invoke(clientObj)
}()
return nil
}))
}

View File

@@ -0,0 +1,100 @@
//go:build js
package http
import (
"fmt"
"io"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
"syscall/js"
"time"
netbird "github.com/netbirdio/netbird/client/embed"
)
const (
httpTimeout = 30 * time.Second
maxResponseSize = 1024 * 1024 // 1MB
)
// performRequest executes an HTTP request through NetBird and returns the response and body
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
httpClient := nbClient.NewHTTPClient()
httpClient.Timeout = httpTimeout
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
if err != nil {
return nil, nil, fmt.Errorf("create request: %w", err)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Errorf("failed to close response body: %v", err)
}
}()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
return resp, respBody, nil
}
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
url := request.Get("url").String()
if url == "" {
return js.Undefined(), fmt.Errorf("URL is required")
}
method := "GET"
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
method = strings.ToUpper(methodVal.String())
}
var requestBody []byte
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
requestBody = []byte(bodyVal.String())
}
requestHeaders := make(map[string]string)
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
for i := 0; i < headerKeys.Length(); i++ {
key := headerKeys.Index(i).String()
value := headersVal.Get(key).String()
requestHeaders[key] = value
}
}
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
if err != nil {
return js.Undefined(), err
}
result := js.Global().Get("Object").New()
result.Set("status", resp.StatusCode)
result.Set("statusText", resp.Status)
result.Set("body", string(body))
headers := js.Global().Get("Object").New()
for key, values := range resp.Header {
if len(values) > 0 {
headers.Set(strings.ToLower(key), values[0])
}
}
result.Set("headers", headers)
return result, nil
}

View File

@@ -0,0 +1,96 @@
//go:build js
package rdp
import (
"crypto/tls"
"crypto/x509"
"fmt"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
certValidationTimeout = 60 * time.Second
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
return false, fmt.Errorf("certificate validation handler not configured")
}
certInfo := js.Global().Get("Object").New()
certInfo.Set("ServerAddr", conn.destination)
certArray := js.Global().Get("Array").New()
for i, certBytes := range certChain {
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
js.CopyBytesToJS(uint8Array, certBytes)
certArray.SetIndex(i, uint8Array)
}
certInfo.Set("ServerCertChain", certArray)
if len(certChain) > 0 {
cert, err := x509.ParseCertificate(certChain[0])
if err == nil {
info := js.Global().Get("Object").New()
info.Set("subject", cert.Subject.String())
info.Set("issuer", cert.Issuer.String())
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
info.Set("serialNumber", cert.SerialNumber.String())
certInfo.Set("CertificateInfo", info)
}
}
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool)
errorChan := make(chan error)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
errorChan <- fmt.Errorf("certificate validation failed")
return nil
}))
select {
case result := <-resultChan:
if result {
log.Info("Certificate accepted by user")
} else {
log.Info("Certificate rejected by user")
}
return result, nil
case err := <-errorChan:
return false, err
case <-time.After(certValidationTimeout):
return false, fmt.Errorf("certificate validation timeout")
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
for _, cert := range cs.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
accepted, err := p.validateCertificateWithJS(conn, certChain)
if err != nil {
return err
}
if !accepted {
return fmt.Errorf("certificate rejected by user")
}
return nil
},
}
}

View File

@@ -0,0 +1,271 @@
//go:build js
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"fmt"
"io"
"net"
"sync"
"syscall/js"
log "github.com/sirupsen/logrus"
)
const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*proxyConnection
destinations map[string]string
mu sync.Mutex
}
type proxyConnection struct {
id string
destination string
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
func NewRDCleanPathProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *RDCleanPathProxy {
return &RDCleanPathProxy{
nbClient: client,
activeConnections: make(map[string]*proxyConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given destination
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := fmt.Sprintf("%s:%s", hostname, port)
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]string)
}
p.destinations[proxyID] = destination
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
}))
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
destination := p.destinations[proxyID]
p.mu.Unlock()
if destination == "" {
log.Errorf("No destination found for proxy ID: %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
// Don't defer cancel here - it will be called by cleanupConnection
conn := &proxyConnection{
id: proxyID,
destination: destination,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
if conn.rdpConn != nil || conn.tlsConn != nil {
p.forwardToRDP(conn, bytes)
return
}
var pdu RDCleanPathPDU
_, err := asn1.Unmarshal(bytes, &pdu)
if err != nil {
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
n := len(bytes)
if n > 20 {
n = 20
}
log.Warnf("First %d bytes: %x", n, bytes[:n])
if len(bytes) > 0 && bytes[0] == 0x03 {
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
go p.handleDirectRDP(conn, bytes)
return
}
return
}
go p.processRDCleanPathPDU(conn, pdu)
}
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
var writer io.Writer
var connType string
if conn.tlsConn != nil {
writer = conn.tlsConn
connType = "TLS"
} else if conn.rdpConn != nil {
writer = conn.rdpConn
connType = "TCP"
} else {
log.Error("No RDP connection available")
return
}
if _, err := writer.Write(bytes); err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
}
}
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
defer p.cleanupConnection(conn)
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
return
}
conn.rdpConn = rdpConn
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
return
}
response := make([]byte, 1024)
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
return
}
p.sendToWebSocket(conn, response[:n])
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}

View File

@@ -0,0 +1,251 @@
//go:build js
package rdp
import (
"crypto/tls"
"encoding/asn1"
"io"
"syscall/js"
log "github.com/sirupsen/logrus"
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
return
}
destination := conn.destination
if pdu.Destination != "" {
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.cleanupConnection(conn)
return
}
conn.rdpConn = rdpConn
// RDP always starts with X.224 negotiation, then determines if TLS is needed
// Modern RDP (since Windows Vista/2008) typically requires TLS
// The X.224 Connection Confirm response will indicate if TLS is required
// For now, we'll attempt TLS for all connections as it's the modern default
p.setupTLSConnection(conn, pdu)
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
return
}
log.Info("TLS handshake successful")
// Certificate validation happens during handshake via VerifyConnection callback
var certChain [][]byte
connState := tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 {
for _, cert := range connState.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
ServerCertChain: certChain,
}
if len(x224Response) > 0 {
responsePDU.X224ConnectionPDU = x224Response
}
p.sendRDCleanPathPDU(conn, responsePDU)
log.Debug("Starting TLS forwarding")
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
<-conn.ctx.Done()
log.Debug("TLS connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
return
}
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
if len(args) < 1 {
errChan <- io.EOF
return nil
}
data := args[0]
if data.InstanceOf(js.Global().Get("Uint8Array")) {
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
msgChan <- bytes
}
return nil
})
defer handler.Release()
conn.wsHandlers.Set("onceGoMessage", handler)
select {
case msg := <-msgChan:
return msg, nil
case err := <-errChan:
return nil, err
case <-conn.ctx.Done():
return nil, conn.ctx.Err()
}
}
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
for {
if conn.ctx.Err() != nil {
return
}
msg, err := p.readWebSocketMessage(conn)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from WebSocket: %v", err)
}
return
}
_, err = dst.Write(msg)
if err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
return
}
}
}
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
buffer := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
n, err := src.Read(buffer)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from %s: %v", connType, err)
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buffer[:n])
}
}
}

View File

@@ -0,0 +1,213 @@
//go:build js
package ssh
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
netbird "github.com/netbirdio/netbird/client/embed"
)
const (
sshDialTimeout = 30 * time.Second
)
func closeWithLog(c io.Closer, resource string) {
if c != nil {
if err := c.Close(); err != nil {
logrus.Debugf("Failed to close %s: %v", resource, err)
}
}
}
type Client struct {
nbClient *netbird.Client
sshClient *ssh.Client
session *ssh.Session
stdin io.WriteCloser
stdout io.Reader
stderr io.Reader
mu sync.RWMutex
}
// NewClient creates a new SSH client
func NewClient(nbClient *netbird.Client) *Client {
return &Client{
nbClient: nbClient,
}
}
// Connect establishes an SSH connection through NetBird network
func (c *Client) Connect(host string, port int, username string) error {
addr := fmt.Sprintf("%s:%d", host, port)
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
var authMethods []ssh.AuthMethod
nbConfig, err := c.nbClient.GetConfig()
if err != nil {
return fmt.Errorf("get NetBird config: %w", err)
}
if nbConfig.SSHKey == "" {
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
}
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return fmt.Errorf("parse NetBird SSH private key: %w", err)
}
pubKey := signer.PublicKey()
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
authMethods = append(authMethods, ssh.PublicKeys(signer))
config := &ssh.ClientConfig{
User: username,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: sshDialTimeout,
}
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel()
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("dial %s: %w", addr, err)
}
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
closeWithLog(conn, "connection after handshake error")
return fmt.Errorf("SSH handshake: %w", err)
}
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
logrus.Infof("SSH: Connected to %s", addr)
return nil
}
// StartSession starts an SSH session with PTY
func (c *Client) StartSession(cols, rows int) error {
if c.sshClient == nil {
return fmt.Errorf("SSH client not connected")
}
session, err := c.sshClient.NewSession()
if err != nil {
return fmt.Errorf("create session: %w", err)
}
c.mu.Lock()
defer c.mu.Unlock()
c.session = session
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
ssh.VINTR: 3,
ssh.VQUIT: 28,
ssh.VERASE: 127,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
closeWithLog(session, "session after PTY error")
return fmt.Errorf("PTY request: %w", err)
}
c.stdin, err = session.StdinPipe()
if err != nil {
closeWithLog(session, "session after stdin error")
return fmt.Errorf("get stdin: %w", err)
}
c.stdout, err = session.StdoutPipe()
if err != nil {
closeWithLog(session, "session after stdout error")
return fmt.Errorf("get stdout: %w", err)
}
c.stderr, err = session.StderrPipe()
if err != nil {
closeWithLog(session, "session after stderr error")
return fmt.Errorf("get stderr: %w", err)
}
if err := session.Shell(); err != nil {
closeWithLog(session, "session after shell error")
return fmt.Errorf("start shell: %w", err)
}
logrus.Info("SSH: Session started with PTY")
return nil
}
// Write sends data to the SSH session
func (c *Client) Write(data []byte) (int, error) {
c.mu.RLock()
stdin := c.stdin
c.mu.RUnlock()
if stdin == nil {
return 0, fmt.Errorf("SSH session not started")
}
return stdin.Write(data)
}
// Read reads data from the SSH session
func (c *Client) Read(buffer []byte) (int, error) {
c.mu.RLock()
stdout := c.stdout
c.mu.RUnlock()
if stdout == nil {
return 0, fmt.Errorf("SSH session not started")
}
return stdout.Read(buffer)
}
// Resize updates the terminal size
func (c *Client) Resize(cols, rows int) error {
c.mu.RLock()
session := c.session
c.mu.RUnlock()
if session == nil {
return fmt.Errorf("SSH session not started")
}
return session.WindowChange(rows, cols)
}
// Close closes the SSH connection
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.session != nil {
closeWithLog(c.session, "SSH session")
c.session = nil
}
if c.stdin != nil {
closeWithLog(c.stdin, "stdin")
c.stdin = nil
}
c.stdout = nil
c.stderr = nil
if c.sshClient != nil {
err := c.sshClient.Close()
c.sshClient = nil
return err
}
return nil
}

View File

@@ -0,0 +1,78 @@
//go:build js
package ssh
import (
"io"
"syscall/js"
"github.com/sirupsen/logrus"
)
// CreateJSInterface creates a JavaScript interface for the SSH client
func CreateJSInterface(client *Client) js.Value {
jsInterface := js.Global().Get("Object").Call("create", js.Null())
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf(false)
}
data := args[0]
var bytes []byte
if data.Type() == js.TypeString {
bytes = []byte(data.String())
} else {
uint8Array := js.Global().Get("Uint8Array").New(data)
length := uint8Array.Get("length").Int()
bytes = make([]byte, length)
js.CopyBytesToGo(bytes, uint8Array)
}
_, err := client.Write(bytes)
return js.ValueOf(err == nil)
}))
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf(false)
}
cols := args[0].Int()
rows := args[1].Int()
err := client.Resize(cols, rows)
return js.ValueOf(err == nil)
}))
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
client.Close()
return js.Undefined()
}))
go readLoop(client, jsInterface)
return jsInterface
}
func readLoop(client *Client, jsInterface js.Value) {
buffer := make([]byte, 4096)
for {
n, err := client.Read(buffer)
if err != nil {
if err != io.EOF {
logrus.Debugf("SSH read error: %v", err)
}
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
onclose.Invoke()
}
client.Close()
return
}
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
uint8Array := js.Global().Get("Uint8Array").New(n)
js.CopyBytesToJS(uint8Array, buffer[:n])
ondata.Invoke(uint8Array)
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build js
package ssh
import (
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
keyStr := string(keyPEM)
if !strings.Contains(keyStr, "-----BEGIN") {
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
}
signer, err := ssh.ParsePrivateKey(keyPEM)
if err == nil {
return signer, nil
}
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
block, _ := pem.Decode(keyPEM)
if block == nil {
keyPreview := string(keyPEM)
if len(keyPreview) > 100 {
keyPreview = keyPreview[:100]
}
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(rsaKey)
}
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(ecKey)
}
return nil, fmt.Errorf("parse private key: %w", err)
}
return ssh.NewSignerFromKey(key)
}

View File

@@ -1,3 +1,5 @@
//go:build !js
package encryption
import (

View File

@@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
return nil, fmt.Errorf("parsing url: %w", err)
}
var opts []grpc.DialOption
if parsedURL.Scheme == "https" {
tlsEnabled := parsedURL.Scheme == "https"
if tlsEnabled {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
}
opts = append(opts,
nbgrpc.WithCustomDialer(),
nbgrpc.WithCustomDialer(tlsEnabled),
grpc.WithIdleTimeout(interval*2),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

2
go.mod
View File

@@ -37,7 +37,7 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.12
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/lib/v4 v4.2.0

4
go.sum
View File

@@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=

View File

@@ -10,6 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
)
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
@@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
})
}
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
return Create(s, func() *server.EphemeralManager {
return server.NewEphemeralManager(s.Store(), s.AccountManager())
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
})
}

View File

@@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager {
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetEphemeralManager(s.EphemeralManager())
})
return accountManager
})
}

View File

@@ -6,12 +6,14 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@@ -22,6 +24,8 @@ import (
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
"github.com/netbirdio/netbird/version"
)
@@ -92,12 +96,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.PeersManager()
s.GeoLocationManager()
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
@@ -147,7 +145,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
@@ -176,6 +174,12 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -247,13 +251,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
return util.DirectWriteJson(ctx, path, config)
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
switch {
case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")):
gRPCHandler.ServeHTTP(writer, request)
} else {
case request.URL.Path == wsproxy.ProxyPath:
wsProxy.Handler().ServeHTTP(writer, request)
default:
httpHandler.ServeHTTP(writer, request)
}
})

View File

@@ -35,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -74,6 +75,7 @@ type DefaultAccountManager struct {
ctx context.Context
eventStore activity.Store
geo geolocation.Geolocation
ephemeralManager ephemeral.Manager
requestBuffer *AccountRequestBuffer
@@ -261,6 +263,10 @@ func BuildManager(
return am, nil
}
func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em
}
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")

View File

@@ -12,6 +12,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -56,7 +57,7 @@ type Manager interface {
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -125,5 +126,6 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
}

View File

@@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
return
@@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
}, false)
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
}, false)
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
@@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
}, false)
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
Status: &nbpeer.PeerStatus{
Connected: true,
LastSeen: time.Now().UTC(),
},
})
}, false)
if err != nil {
t.Fatalf("expecting peer to be added, got failure %v", err)
}
@@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
}, false)
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
require.NoError(t, err, "unable to add peer2")
t.Run("update peer IP successfully", func(t *testing.T) {

View File

@@ -281,11 +281,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return nil, err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -22,6 +22,7 @@ import (
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -55,7 +56,7 @@ type GRPCServer struct {
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
ephemeralManager ephemeral.Manager
peerLocks sync.Map
authManager auth.Manager
@@ -73,7 +74,7 @@ func NewServer(
peersUpdateManager *PeersUpdateManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager,
ephemeralManager ephemeral.Manager,
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) {

View File

@@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
var req api.PeerTemporaryAccessRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
newPeer := &nbpeer.Peer{}
newPeer.FromAPITemporaryAccessRequest(&req)
targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
for _, rule := range req.Rules {
protocol, portRange, err := types.ParseRuleString(rule)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policy := &types.Policy{
AccountID: userAuth.AccountId,
Description: "Temporary access policy for peer " + peer.Name,
Name: "Temporary access policy for peer " + peer.Name,
Enabled: true,
Rules: []*types.PolicyRule{{
Name: "Temporary access rule",
Description: "Temporary access rule",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
SourceResource: types.Resource{
Type: types.ResourceTypePeer,
ID: peer.ID,
},
DestinationResource: types.Resource{
Type: types.ResourceTypePeer,
ID: targetPeer.ID,
},
Bidirectional: false,
Protocol: protocol,
PortRanges: []types.RulePortRange{portRange},
}},
}
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
}
resp := &api.PeerTemporaryAccessResponse{
Id: peer.ID,
Name: peer.Name,
Rules: req.Rules,
}
util.WriteJSONObject(r.Context(), w, resp)
}
func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager)
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil {
return nil, nil, "", cleanup, err

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -228,7 +229,7 @@ func startServer(
peersUpdateManager,
secretsManager,
nil,
nil,
&manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
)

View File

@@ -15,6 +15,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -41,7 +42,7 @@ type MockAccountManager struct {
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
@@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string
// AddPeer mock implementation of AddPeer from server.AccountManager interface
func (am *MockAccountManager) AddPeer(
ctx context.Context,
accountID string,
setupKey string,
userId string,
peer *nbpeer.Peer,
temporary bool,
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(ctx, setupKey, userId, peer)
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
}
@@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface
func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) {
// Mock implementation - does nothing
}
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)

View File

@@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -132,7 +132,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
res := nbtypes.Resource{
ID: resource.ID,
Type: resource.Type.String(),
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
@@ -265,7 +265,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) {
res := nbtypes.Resource{
ID: newResource.ID,
Type: newResource.Type.String(),
Type: nbtypes.ResourceType(newResource.Type.String()),
}
oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID)

View File

@@ -450,7 +450,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if setupKey == "" && userID == "" {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
@@ -482,8 +482,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
var ephemeral bool
var groupsToAdd []string
var allowExtraDNSLabels bool
var accountID string
var isEphemeral bool
if addedByUser {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
@@ -492,10 +490,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if user.PendingApproval {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
}
groupsToAdd = user.AutoGroups
if temporary {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create)
if err != nil {
return nil, nil, nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, nil, nil, status.NewPermissionDeniedError()
}
} else {
accountID = user.AccountID
groupsToAdd = user.AutoGroups
}
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
accountID = user.AccountID
} else {
// Validate the setup key
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
@@ -516,13 +525,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
setupKeyName = sk.Name
allowExtraDNSLabels = sk.AllowExtraDNSLabels
accountID = sk.AccountID
isEphemeral = sk.Ephemeral
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
}
}
opEvent.AccountID = accountID
if temporary {
ephemeral = true
}
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
@@ -549,10 +561,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
SSHKey: peer.SSHKey,
LastLogin: &registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
LoginExpirationEnabled: addedByUser && !temporary,
Ephemeral: ephemeral,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
InactivityExpirationEnabled: addedByUser && !temporary,
ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels,
}
@@ -588,7 +600,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var freeLabel string
if isEphemeral || attempt > 1 {
if ephemeral || attempt > 1 {
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
@@ -622,6 +634,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually
am.ephemeralManager.OnPeerDisconnected(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -790,7 +807,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
ExtraDNSLabels: login.ExtraDNSLabels,
}
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
return am.AddPeer(ctx, "", login.SetupKey, login.UserID, newPeer, false)
}
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
@@ -877,6 +894,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
updateRemotePeers = true
}
if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 {
@@ -1540,6 +1558,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err
}
peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, err
}
for _, rule := range peerPolicyRules {
policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
if err != nil {
return nil, err
}
err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
if err != nil {
return nil, err
}
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
})
}
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Peer represents a machine connected to the network.
@@ -334,6 +335,17 @@ func (p *Peer) UpdateLastLogin() *Peer {
return p
}
func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) {
p.Ephemeral = true
p.Name = a.Name
p.Key = a.WgPubKey
p.Meta = PeerSystemMeta{
Hostname: a.Name,
GoOS: "js",
OS: "js",
}
}
func (f Flags) isEqual(other Flags) bool {
return f.RosenpassEnabled == other.RosenpassEnabled &&
f.RosenpassPermissive == other.RosenpassPermissive &&

View File

@@ -193,10 +193,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -207,10 +207,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
@@ -266,10 +266,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -280,10 +280,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err)
return
}
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -442,10 +442,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -456,10 +456,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
@@ -514,10 +514,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -530,10 +530,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// the second peer added with a setup key
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Fatal(err)
return
@@ -702,19 +702,19 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
}, false)
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
@@ -1300,7 +1300,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
},
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false)
require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
@@ -1422,7 +1422,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer)
addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false)
if tc.expectAddPeerError {
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
@@ -1523,7 +1523,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
_, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false)
require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
@@ -1658,7 +1658,7 @@ func Test_LoginPeer(t *testing.T) {
if sk.AllowExtraDNSLabels {
currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels
}
_, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer)
_, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false)
require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey)
loginInput := types.PeerLogin{
@@ -1797,10 +1797,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
require.NoError(t, err)
select {
@@ -1918,11 +1918,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
require.NoError(t, err)
select {
@@ -1982,11 +1982,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
require.NoError(t, err)
select {
@@ -2037,11 +2037,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{
peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
require.NoError(t, err)
select {
@@ -2208,7 +2208,7 @@ func Test_AddPeer(t *testing.T) {
<-start
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
@@ -2416,7 +2416,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer)
_, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
}
@@ -2451,7 +2451,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer)
_, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false)
require.NoError(t, err, "Regular user should be able to add peers")
}
@@ -2494,7 +2494,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer)
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false)
require.NoError(t, err)
// Now set the user back to pending approval after peer was created
@@ -2550,7 +2550,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer)
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false)
require.NoError(t, err)
// Try to login with regular user

View File

@@ -0,0 +1,14 @@
package ephemeral
import (
"context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type Manager interface {
LoadInitialPeers(ctx context.Context)
Stop()
OnPeerConnected(ctx context.Context, peer *nbpeer.Peer)
OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer)
}

View File

@@ -1,4 +1,4 @@
package server
package manager
import (
"context"

View File

@@ -1,4 +1,4 @@
package server
package manager
import (
"context"
@@ -7,12 +7,15 @@ import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
type MockStore struct {
@@ -223,3 +226,57 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int)
store.account.Peers[p.ID] = p
}
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := types.NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := types.DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
log.WithContext(ctx).Debugf("created new account %s", accountID)
acc := &types.Account{
Id: accountID,
CreatedAt: time.Now().UTC(),
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userID,
Domain: domain,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
},
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
},
}
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
}

View File

@@ -151,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
@@ -161,6 +167,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}

View File

@@ -2037,6 +2037,25 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string)
})
}
func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var policyRules []*types.PolicyRule
resourceIDPattern := `%"ID":"` + resourceID + `"%`
result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern).
Find(&policyRules)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store")
}
return policyRules, nil
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
tx := s.db

View File

@@ -202,6 +202,7 @@ type Store interface {
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
}
const (

View File

@@ -1001,8 +1001,20 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
var sourcePeers, destinationPeers []*nbpeer.Peer
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID)
} else {
sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID)
} else {
destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
}
if rule.Bidirectional {
if peerInSources {
@@ -1124,6 +1136,15 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
return filteredPeers, peerInGroups
}
func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
peer := a.GetPeer(resource.ID)
if peer == nil {
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peer}, resource.ID == peerID
}
// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
@@ -1379,7 +1400,12 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] {
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}

View File

@@ -1,5 +1,12 @@
package types
import (
"errors"
"fmt"
"strconv"
"strings"
)
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
@@ -134,3 +141,83 @@ func (p *Policy) SourceGroups() []string {
return groupIDs
}
func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) {
rule = strings.TrimSpace(strings.ToLower(rule))
if rule == "all" {
return PolicyRuleProtocolALL, RulePortRange{}, nil
}
if rule == "icmp" {
return PolicyRuleProtocolICMP, RulePortRange{}, nil
}
split := strings.Split(rule, "/")
if len(split) != 2 {
return "", RulePortRange{}, errors.New("invalid rule format: expected protocol/port or protocol/port-range")
}
protoStr := strings.TrimSpace(split[0])
portStr := strings.TrimSpace(split[1])
var protocol PolicyRuleProtocolType
switch protoStr {
case "tcp":
protocol = PolicyRuleProtocolTCP
case "udp":
protocol = PolicyRuleProtocolUDP
case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
}
portRange, err := parsePortRange(portStr)
if err != nil {
return "", RulePortRange{}, err
}
return protocol, portRange, nil
}
func parsePortRange(portStr string) (RulePortRange, error) {
if strings.Contains(portStr, "-") {
rangeParts := strings.Split(portStr, "-")
if len(rangeParts) != 2 {
return RulePortRange{}, fmt.Errorf("invalid port range %q", portStr)
}
start, err := parsePort(strings.TrimSpace(rangeParts[0]))
if err != nil {
return RulePortRange{}, err
}
end, err := parsePort(strings.TrimSpace(rangeParts[1]))
if err != nil {
return RulePortRange{}, err
}
if start > end {
return RulePortRange{}, fmt.Errorf("invalid port range: start %d > end %d", start, end)
}
return RulePortRange{Start: uint16(start), End: uint16(end)}, nil
}
p, err := parsePort(portStr)
if err != nil {
return RulePortRange{}, err
}
return RulePortRange{Start: uint16(p), End: uint16(p)}, nil
}
func parsePort(portStr string) (int, error) {
if portStr == "" {
return 0, errors.New("empty port")
}
p, err := strconv.Atoi(portStr)
if err != nil {
return 0, fmt.Errorf("invalid port %q: %w", portStr, err)
}
if p < 1 || p > 65535 {
return 0, fmt.Errorf("port out of range (165535): %d", p)
}
return p, nil
}

View File

@@ -4,9 +4,18 @@ import (
"github.com/netbirdio/netbird/shared/management/http/api"
)
type ResourceType string
const (
ResourceTypePeer ResourceType = "peer"
ResourceTypeDomain ResourceType = "domain"
ResourceTypeHost ResourceType = "host"
ResourceTypeSubnet ResourceType = "subnet"
)
type Resource struct {
ID string
Type string
Type ResourceType
}
func (r *Resource) ToAPIResponse() *api.Resource {
@@ -26,5 +35,5 @@ func (r *Resource) FromAPIRequest(req *api.Resource) {
}
r.ID = req.Id
r.Type = string(req.Type)
r.Type = ResourceType(req.Type)
}

View File

@@ -1439,10 +1439,10 @@ func TestUserAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
}, false)
require.NoError(t, err)
// updating user with linked peers should update account peers and send peer update

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -27,6 +28,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -117,7 +119,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -507,6 +507,48 @@ components:
- serial_number
- extra_dns_labels
- ephemeral
PeerTemporaryAccessRequest:
type: object
properties:
name:
description: Peer's hostname
type: string
example: temp-host-1
wg_pub_key:
description: Peer's WireGuard public key
type: string
example: "n0r3pL4c3h0ld3rK3y=="
rules:
description: List of temporary access rules
type: array
items:
type: string
example: "tcp/80"
required:
- name
- wg_pub_key
- rules
PeerTemporaryAccessResponse:
type: object
properties:
name:
description: Peer's hostname
type: string
example: temp-host-1
id:
description: Peer ID
type: string
example: chacbco6lnnbn6cg5s90
rules:
description: List of temporary access rules
type: array
items:
type: string
example: "tcp/80"
required:
- name
- id
- rules
AccessiblePeer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
@@ -1404,7 +1446,8 @@ components:
allOf:
- $ref: '#/components/schemas/NetworkResourceType'
- type: string
example: host
enum: ["peer"]
example: peer
NetworkRequest:
type: object
properties:
@@ -2793,6 +2836,42 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/temporary-access:
post:
summary: Create a Temporary Access Peer
description: Creates a temporary access peer that can be used to access this peer and this peer only. The temporary access peer and its access policies will be automatically deleted after it disconnects.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
requestBody:
description: Temporary Access Peer create request
content:
'application/json':
schema:
$ref: '#/components/schemas/PeerTemporaryAccessRequest'
responses:
'200':
description: Temporary Access Peer response
content:
application/json:
schema:
$ref: '#/components/schemas/PeerTemporaryAccessResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/ingress/ports:
get:
x-cloud-only: true

View File

@@ -168,6 +168,7 @@ const (
const (
ResourceTypeDomain ResourceType = "domain"
ResourceTypeHost ResourceType = "host"
ResourceTypePeer ResourceType = "peer"
ResourceTypeSubnet ResourceType = "subnet"
)
@@ -1221,6 +1222,30 @@ type PeerRequest struct {
SshEnabled bool `json:"ssh_enabled"`
}
// PeerTemporaryAccessRequest defines model for PeerTemporaryAccessRequest.
type PeerTemporaryAccessRequest struct {
// Name Peer's hostname
Name string `json:"name"`
// Rules List of temporary access rules
Rules []string `json:"rules"`
// WgPubKey Peer's WireGuard public key
WgPubKey string `json:"wg_pub_key"`
}
// PeerTemporaryAccessResponse defines model for PeerTemporaryAccessResponse.
type PeerTemporaryAccessResponse struct {
// Id Peer ID
Id string `json:"id"`
// Name Peer's hostname
Name string `json:"name"`
// Rules List of temporary access rules
Rules []string `json:"rules"`
}
// PersonalAccessToken defines model for PersonalAccessToken.
type PersonalAccessToken struct {
// CreatedAt Date the token was created
@@ -1949,6 +1974,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques
// PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType.
type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest
// PostApiPeersPeerIdTemporaryAccessJSONRequestBody defines body for PostApiPeersPeerIdTemporaryAccess for application/json ContentType.
type PostApiPeersPeerIdTemporaryAccessJSONRequestBody = PeerTemporaryAccessRequest
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
type PostApiPoliciesJSONRequestBody = PolicyUpdate

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v4.24.3
// protoc v6.32.0
// source: management.proto
package proto

View File

@@ -9,11 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -296,14 +293,7 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
// Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues
var dialers []dialer.DialeFn
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
dialers = []dialer.DialeFn{ws.Dialer{}}
} else {
dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
}
dialers := c.getDialers()
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
conn, err := rd.Dial()

View File

@@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return 0, err
return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b)
}
func (c *Conn) RemoteAddr() net.Addr {

View File

@@ -0,0 +1,11 @@
//go:build !js
package ws
import "github.com/coder/websocket"
func createDialOptions() *websocket.DialOptions {
return &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
}

View File

@@ -0,0 +1,10 @@
//go:build js
package ws
import "github.com/coder/websocket"
func createDialOptions() *websocket.DialOptions {
// WASM version doesn't support HTTPClient
return &websocket.DialOptions{}
}

View File

@@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
opts := createDialOptions()
parsedURL, err := url.Parse(wsURL)
if err != nil {

View File

@@ -0,0 +1,19 @@
//go:build !js
package client
import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
}
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
}

Some files were not shown because too many files have changed in this diff Show More