mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 07:33:52 -04:00
Compare commits
3 Commits
feature/st
...
fix/up-seq
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55d8b4e42c | ||
|
|
9b8f7d75b3 | ||
|
|
57f3af57f4 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
SIGN_PIPE_VER: "v0.0.22"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
67
.github/workflows/wasm-build-validation.yml
vendored
67
.github/workflows/wasm-build-validation.yml
vendored
@@ -1,67 +0,0 @@
|
||||
name: Wasm
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js_lint:
|
||||
name: "JS / Lint"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
skip-cache: true
|
||||
skip-pkg-cache: true
|
||||
skip-build-cache: true
|
||||
- name: Run golangci-lint for WASM
|
||||
run: |
|
||||
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||
continue-on-error: true
|
||||
|
||||
js_build:
|
||||
name: "JS / Build"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
- name: Check Wasm build size
|
||||
run: |
|
||||
echo "Wasm build size:"
|
||||
ls -lh netbird.wasm
|
||||
|
||||
SIZE=$(stat -c%s netbird.wasm)
|
||||
SIZE_MB=$((SIZE / 1024 / 1024))
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 52428800 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
0
.gitmodules
vendored
0
.gitmodules
vendored
@@ -2,18 +2,6 @@ version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird-wasm
|
||||
dir: client/wasm/cmd
|
||||
binary: netbird
|
||||
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||
goos:
|
||||
- js
|
||||
goarch:
|
||||
- wasm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird
|
||||
dir: client
|
||||
binary: netbird
|
||||
@@ -127,11 +115,6 @@ archives:
|
||||
- builds:
|
||||
- netbird
|
||||
- netbird-static
|
||||
- id: netbird-wasm
|
||||
builds:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
@@ -53,7 +52,7 @@
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
@@ -18,7 +18,7 @@ ENV \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import "context"
|
||||
|
||||
// SetupDebugHandler is a no-op for WASM
|
||||
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||
// Debug handler not needed for WASM
|
||||
}
|
||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -116,7 +114,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -230,9 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||
WaitForReady: func() *bool { b := true; return &b }(),
|
||||
})
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{WaitForConnectingShift: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
@@ -23,29 +23,23 @@ import (
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
// Client manages a netbird embedded client instance
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *profilemanager.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
// Options configures a new Client
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// JWTToken is used for JWT-based authentication
|
||||
JWTToken string
|
||||
// PrivateKey is used for direct private key authentication
|
||||
PrivateKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
@@ -64,35 +58,8 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
func (opts *Options) validateCredentials() error {
|
||||
credentialsProvided := 0
|
||||
if opts.SetupKey != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
if opts.JWTToken != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
if opts.PrivateKey != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
|
||||
if credentialsProvided == 0 {
|
||||
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
|
||||
}
|
||||
if credentialsProvided > 1 {
|
||||
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client.
|
||||
// New creates a new netbird embedded client
|
||||
func New(opts Options) (*Client, error) {
|
||||
if err := opts.validateCredentials(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
@@ -140,14 +107,9 @@ func New(opts Options) (*Client, error) {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
|
||||
if opts.PrivateKey != "" {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
@@ -164,7 +126,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
@@ -173,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{})
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
@@ -225,16 +187,6 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns a copy of the internal client config.
|
||||
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.config == nil {
|
||||
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||
}
|
||||
return *c.config, nil
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@@ -259,7 +211,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
@@ -280,7 +232,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network.
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||
)
|
||||
|
||||
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
return client.WithWebSocketDialer(tlsEnabled, component)
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package bind
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
|
||||
@@ -1,17 +1,5 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
type Endpoint = wgConn.StdNetEndpoint
|
||||
|
||||
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: e.Addr().AsSlice(),
|
||||
Port: int(e.Port()),
|
||||
Zone: e.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package bind
|
||||
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
|
||||
)
|
||||
@@ -1,9 +1,6 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -20,9 +17,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
@@ -40,38 +42,37 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
|
||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
RecvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
recvChan chan recvMessage
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
activityRecorder *ActivityRecorder
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
activityRecorder *ActivityRecorder
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
recvChan: make(chan recvMessage, 1),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
mtu: mtu,
|
||||
address: address,
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
}
|
||||
|
||||
@@ -82,6 +83,10 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
|
||||
return ib
|
||||
}
|
||||
|
||||
func (s *ICEBind) MTU() uint16 {
|
||||
return s.mtu
|
||||
}
|
||||
|
||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
s.closed = false
|
||||
s.closedChanMu.Lock()
|
||||
@@ -134,16 +139,6 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
delete(b.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||
select {
|
||||
case <-b.closedChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case b.recvChan <- recvMessage{ep, buf}:
|
||||
}
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
b.endpointsMu.Lock()
|
||||
conn, ok := b.endpoints[ep.DstIP()]
|
||||
@@ -276,7 +271,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.recvChan:
|
||||
case msg, ok := <-c.RecvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
package bind
|
||||
|
||||
type recvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
)
|
||||
|
||||
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
|
||||
// Do not limit to build only js, because we want to be able to run tests
|
||||
type RelayBindJS struct {
|
||||
*conn.StdNetBind
|
||||
|
||||
recvChan chan recvMessage
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
activityRecorder *ActivityRecorder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewRelayBindJS() *RelayBindJS {
|
||||
return &RelayBindJS{
|
||||
recvChan: make(chan recvMessage, 100),
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Open creates a receive function for handling relay packets in WASM.
|
||||
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
log.Debugf("Open: creating receive function for port %d", uport)
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
|
||||
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-s.recvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
copy(bufs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) Close() error {
|
||||
if s.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
log.Debugf("close RelayBindJS")
|
||||
s.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case s.recvChan <- recvMessage{ep, buf}:
|
||||
}
|
||||
}
|
||||
|
||||
// Send forwards packets through the relay connection for WASM.
|
||||
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
if ep == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fakeIP := ep.DstIP()
|
||||
|
||||
s.endpointsMu.Lock()
|
||||
relayConn, ok := s.endpoints[fakeIP]
|
||||
s.endpointsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if _, err := relayConn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
s.endpointsMu.Lock()
|
||||
defer s.endpointsMu.Unlock()
|
||||
|
||||
delete(s.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, ErrUDPMUXNotSupported
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the existing peer to preserve its allowed IPs
|
||||
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer: %w", err)
|
||||
}
|
||||
|
||||
removePeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
|
||||
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
|
||||
}
|
||||
|
||||
//Re-add the peer without the endpoint but same AllowedIPs
|
||||
reAddPeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
AllowedIPs: existingPeer.AllowedIPs,
|
||||
ReplaceAllowedIPs: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
|
||||
return fmt.Errorf(
|
||||
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
|
||||
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || windows || freebsd || js || wasip1
|
||||
//go:build linux || windows || freebsd
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !windows && !js
|
||||
//go:build !windows
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type noopListener struct{}
|
||||
|
||||
func (n *noopListener) Accept() (net.Conn, error) {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (n *noopListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopListener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
return &noopListener{}, nil
|
||||
}
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse peer key: %w", err)
|
||||
}
|
||||
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get IPC config: %w", err)
|
||||
}
|
||||
|
||||
// Parse current status to get allowed IPs for the peer
|
||||
stats, err := parseStatus(c.deviceName, ipcStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse IPC config: %w", err)
|
||||
}
|
||||
|
||||
var allowedIPs []net.IPNet
|
||||
found := false
|
||||
for _, peer := range stats.Peers {
|
||||
if peer.PublicKey == peerKey {
|
||||
allowedIPs = peer.AllowedIPs
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("peer %s not found", peerKey)
|
||||
}
|
||||
|
||||
// remove the peer from the WireGuard configuration
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return fmt.Errorf("failed to remove peer: %s", ipcErr)
|
||||
}
|
||||
|
||||
// Build the peer config
|
||||
peer = wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: allowedIPs,
|
||||
}
|
||||
|
||||
config = wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
|
||||
return fmt.Errorf("remove endpoint address: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
@@ -470,7 +409,7 @@ func toBytes(s string) (int64, error) {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
@@ -101,8 +101,13 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var udpConn net.PacketConn = rawSock
|
||||
if !nbnet.AdvancedRouting() {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
UDPConn: udpConn,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
@@ -14,15 +12,9 @@ import (
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Bind interface {
|
||||
conn.Bind
|
||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
ActivityRecorder() *bind.ActivityRecorder
|
||||
}
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
@@ -30,7 +22,7 @@ type TunNetstackDevice struct {
|
||||
key string
|
||||
mtu uint16
|
||||
listenAddress string
|
||||
bind Bind
|
||||
iceBind *bind.ICEBind
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
@@ -41,7 +33,7 @@ type TunNetstackDevice struct {
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -49,7 +41,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
bind: bind,
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,11 +66,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
|
||||
t.device = device.NewDevice(
|
||||
t.filteredDevice,
|
||||
t.bind,
|
||||
t.iceBind,
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
@@ -99,15 +91,11 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpMux, err := t.bind.GetICEMux()
|
||||
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
|
||||
udpMux, err := t.iceBind.GetICEMux()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if udpMux != nil {
|
||||
t.udpMux = udpMux
|
||||
}
|
||||
|
||||
t.udpMux = udpMux
|
||||
log.Debugf("netstack device is ready to use")
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestNewNetstackDevice(t *testing.T) {
|
||||
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
|
||||
|
||||
cfgr, err := nsTun.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create netstack device: %v", err)
|
||||
}
|
||||
if cfgr == nil {
|
||||
t.Fatal("expected non-nil configurer")
|
||||
}
|
||||
}
|
||||
@@ -21,5 +21,4 @@ type WGConfigurer interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
RemoveEndpointAddress(peerKey string) error
|
||||
}
|
||||
|
||||
@@ -148,17 +148,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Removing endpoint address: %s", peerKey)
|
||||
return w.configurer.RemoveEndpointAddress(peerKey)
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
package iface
|
||||
|
||||
// Destroy is a no-op on WASM
|
||||
func (w *WGIface) Destroy() error {
|
||||
return nil
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
|
||||
wgIface := &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !js
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package netstack
|
||||
|
||||
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||
|
||||
// IsEnabled always returns true for js since it's the only mode available
|
||||
func IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func ListenAddr() string {
|
||||
return ""
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
|
||||
@@ -16,38 +16,28 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
type Bind interface {
|
||||
SetEndpoint(addr netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(addr netip.Addr)
|
||||
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
|
||||
}
|
||||
|
||||
type ProxyBind struct {
|
||||
bind Bind
|
||||
Bind *bind.ICEBind
|
||||
|
||||
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
|
||||
wgRelayedEndpoint *bind.Endpoint
|
||||
wgCurrentUsed *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
|
||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
p := &ProxyBind{
|
||||
bind: bind,
|
||||
Bind: bind,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
mtu: mtu + bufsize.WGBufferOverhead,
|
||||
}
|
||||
|
||||
return p
|
||||
@@ -56,25 +46,25 @@ func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
|
||||
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
|
||||
// - remoteConn: The established TURN connection to the remote peer
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||
@@ -86,21 +76,17 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = p.wgRelayedEndpoint
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
// Start the proxy only once
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Pause() {
|
||||
@@ -108,25 +94,9 @@ func (p *ProxyBind) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
@@ -137,10 +107,6 @@ func (p *ProxyBind) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -154,12 +120,7 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -175,7 +136,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}()
|
||||
|
||||
for {
|
||||
buf := make([]byte, p.mtu)
|
||||
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -186,13 +147,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -16,20 +18,15 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
@@ -67,7 +64,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -217,17 +214,57 @@ generatePort:
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
SrcPort: layers.UDPPort(port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
@@ -242,7 +279,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -18,42 +18,41 @@ import (
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||
return &ProxyWrapper{
|
||||
wgeBPFProxy: proxy,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
WgeBPFProxy: WgeBPFProxy,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgRelayedEndpointAddr
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||
@@ -65,18 +64,14 @@ func (p *ProxyWrapper) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
@@ -85,59 +80,45 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (p *ProxyWrapper) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
p.cancel()
|
||||
e.cancel()
|
||||
|
||||
p.closeListener.SetCloseListener(nil)
|
||||
e.closeListener.SetCloseListener(nil)
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
p.pausedCond.L.Unlock()
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -156,7 +137,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
}
|
||||
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
|
||||
31
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
31
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// KernelFactory todo: check eBPF support on FreeBSD
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
@@ -3,25 +3,24 @@ package wgproxy
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
type USPFactory struct {
|
||||
bind proxyBind.Bind
|
||||
mtu uint16
|
||||
bind *bind.ICEBind
|
||||
}
|
||||
|
||||
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
|
||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||
f := &USPFactory{
|
||||
bind: bind,
|
||||
mtu: mtu,
|
||||
bind: iceBind,
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
return proxyBind.NewProxyBind(w.bind)
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
|
||||
@@ -11,11 +11,6 @@ type Proxy interface {
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
|
||||
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||
//and rewrite the src address to the endpoint address.
|
||||
//With this logic can avoid the package loss from relayed connections.
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
}
|
||||
|
||||
@@ -3,82 +3,54 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
t.Skip("Skipping test as it requires root privileges")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "ebpf proxy",
|
||||
proxy: &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
// todo extend with Bind proxy
|
||||
pl := make([]proxyInstance, 0)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
return pl, nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
@@ -5,9 +7,12 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -17,14 +22,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type proxyInstance struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
wgPort int
|
||||
endpointAddr *net.UDPAddr
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
@@ -81,21 +78,41 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests, err := seedProxyForProxyCloseByRemoteConn()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
|
||||
},
|
||||
}
|
||||
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
defer func() {
|
||||
_ = relayedConn.Close()
|
||||
}()
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
@@ -107,104 +124,3 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyRedirect todo extend the proxies with Bind proxy
|
||||
func TestProxyRedirect(t *testing.T) {
|
||||
tests, err := seedProxies()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
|
||||
if err := tt.closeFn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
msgHelloFromRelay := []byte("hello from relay")
|
||||
msgRedirected := [][]byte{
|
||||
[]byte("hello 1. to p2p"),
|
||||
[]byte("hello 2. to p2p"),
|
||||
[]byte("hello 3. to p2p"),
|
||||
}
|
||||
|
||||
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: wgPort})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on udp port: %s", err)
|
||||
}
|
||||
|
||||
relayedServer, _ := net.ListenUDP("udp",
|
||||
&net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
},
|
||||
)
|
||||
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
|
||||
defer func() {
|
||||
_ = dummyWgListener.Close()
|
||||
_ = relayedConn.Close()
|
||||
_ = relayedServer.Close()
|
||||
}()
|
||||
|
||||
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
|
||||
}
|
||||
|
||||
n, err := dummyWgListener.Read(make([]byte, 1024))
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgHelloFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
|
||||
}
|
||||
|
||||
p2pEndpointAddr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 0, 56),
|
||||
Port: 1234,
|
||||
}
|
||||
proxy.RedirectAs(p2pEndpointAddr)
|
||||
|
||||
for _, msg := range msgRedirected {
|
||||
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(msgRedirected); i++ {
|
||||
buf := make([]byte, 1024)
|
||||
n, rAddr, err := dummyWgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if rAddr.String() != p2pEndpointAddr.String() {
|
||||
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
|
||||
}
|
||||
if string(buf[:n]) != string(msgRedirected[i]) {
|
||||
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package rawsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
@@ -23,18 +21,16 @@ type WGUDPProxy struct {
|
||||
localWGListenPort int
|
||||
mtu uint16
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
srcFakerConn *SrcFaker
|
||||
sendPkg func(data []byte) (int, error)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
@@ -45,7 +41,6 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
mtu: mtu,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
return p
|
||||
@@ -66,7 +61,6 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.sendPkg = p.localConn.Write
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
@@ -90,24 +84,15 @@ func (p *WGUDPProxy) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.sendPkg = p.localConn.Write
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
@@ -116,35 +101,9 @@ func (p *WGUDPProxy) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
defer func() {
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}()
|
||||
|
||||
p.paused = false
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create src faker conn: %s", err)
|
||||
// fallback to continue without redirecting
|
||||
p.paused = true
|
||||
return
|
||||
}
|
||||
p.srcFakerConn = srcFakerConn
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@@ -156,8 +115,6 @@ func (p *WGUDPProxy) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) close() error {
|
||||
var result *multierror.Error
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -171,11 +128,7 @@ func (p *WGUDPProxy) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
@@ -183,13 +136,6 @@ func (p *WGUDPProxy) close() error {
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
return cerrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
@@ -248,12 +194,14 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
_, err = p.sendPkg(buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
)
|
||||
|
||||
var (
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *SrcFaker) Close() error {
|
||||
return f.rawSocket.Close()
|
||||
}
|
||||
|
||||
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
defer func() {
|
||||
if err := f.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return ipH, udpH, nil
|
||||
}
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
|
||||
@@ -240,19 +240,15 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
@@ -405,7 +401,6 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
@@ -417,7 +412,6 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||
return &noopHostConfigurator{}, nil
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ServiceViaMemory struct {
|
||||
|
||||
@@ -31,7 +31,6 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
@@ -103,11 +102,6 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ShutdownState struct{}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
ip4Addrs []netip.Addr
|
||||
ip6Addrs []netip.Addr
|
||||
}
|
||||
|
||||
func newCache() *cache {
|
||||
return &cache{
|
||||
records: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.records[normalizeDomain(domain)]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
return slices.Clone(entry.ip4Addrs), true
|
||||
case dns.TypeAAAA:
|
||||
return slices.Clone(entry.ip6Addrs), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
norm := normalizeDomain(domain)
|
||||
entry, exists := c.records[norm]
|
||||
if !exists {
|
||||
entry = &cacheEntry{}
|
||||
c.records[norm] = entry
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
entry.ip4Addrs = slices.Clone(addrs)
|
||||
case dns.TypeAAAA:
|
||||
entry.ip6Addrs = slices.Clone(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// unset removes cached entries for the given domain and request type.
|
||||
func (c *cache) unset(domain string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.records, normalizeDomain(domain))
|
||||
}
|
||||
|
||||
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||
// lowercase and fully-qualified (with trailing dot).
|
||||
func normalizeDomain(domain string) string {
|
||||
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||
return dns.Fqdn(strings.ToLower(domain))
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||
t.Helper()
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse addr %s: %v", s, err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestCacheNormalization(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
// Mixed case, without trailing dot
|
||||
domainInput := "ExAmPlE.CoM"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||
|
||||
// Lookup with lower, with trailing dot
|
||||
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Lookup with different casing again
|
||||
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSeparateTypes(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
domain := "test.local"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||
|
||||
c.set(domain, 1 /* A */, ipv4)
|
||||
c.set(domain, 28 /* AAAA */, ipv6)
|
||||
|
||||
got4, ok4 := c.get(domain, 1)
|
||||
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||
}
|
||||
|
||||
got6, ok6 := c.get(domain, 28)
|
||||
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||
c := newCache()
|
||||
domain := "clone.test"
|
||||
|
||||
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||
c.set(domain, 1, src)
|
||||
|
||||
// Mutate source slice; cache should be unaffected
|
||||
src[0] = mustAddr(t, "9.9.9.9")
|
||||
|
||||
got, ok := c.get(domain, 1)
|
||||
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Mutate returned slice; internal cache should remain unchanged
|
||||
got[0] = mustAddr(t, "4.4.4.4")
|
||||
got2, ok2 := c.get(domain, 1)
|
||||
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := newCache()
|
||||
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ type DNSForwarder struct {
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
@@ -57,7 +56,6 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,39 +103,10 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// remove cache entries for domains that no longer appear
|
||||
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
}
|
||||
|
||||
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||
// in the old list but not present in the new list.
|
||||
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||
if f.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newSet := make(map[string]struct{}, len(newEntries))
|
||||
for _, e := range newEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, e := range oldEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
pattern := e.Domain.PunycodeString()
|
||||
if _, ok := newSet[pattern]; !ok {
|
||||
f.cache.unset(pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -202,7 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -314,69 +282,29 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
}
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -648,95 +648,6 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
// Ensures that when the first query succeeds and populates the cache,
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First call resolves successfully and populates cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
// Second call fails upstream; forwarder should serve from cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
// First query: populate cache
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("9.8.7.6")
|
||||
|
||||
// Initial resolution with mixed case to populate cache
|
||||
mixedQuery := "ExAmPlE.CoM"
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,18 +11,14 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
listenPort uint16 = 5353
|
||||
listenPortMu sync.RWMutex
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTTL = 60 //seconds
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
ListenPort = 5353
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
@@ -42,18 +37,6 @@ type Manager struct {
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
listenPortMu.RLock()
|
||||
defer listenPortMu.RUnlock()
|
||||
return listenPort
|
||||
}
|
||||
|
||||
func SetListenPort(port uint16) {
|
||||
listenPortMu.Lock()
|
||||
listenPort = port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
@@ -71,7 +54,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -111,7 +94,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort()},
|
||||
Values: []uint16{ListenPort},
|
||||
}
|
||||
|
||||
if m.firewall == nil {
|
||||
|
||||
@@ -198,13 +198,6 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -247,7 +240,6 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
dnsFwdPort: dnsfwd.ListenPort(),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -349,9 +341,6 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -468,7 +457,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
iceCfg := e.createICEConfig()
|
||||
iceCfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||
e.connMgr.Start(e.ctx)
|
||||
@@ -481,22 +477,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1084,7 +1064,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
@@ -1342,7 +1322,14 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
Addr: e.getRosenpassAddr(),
|
||||
PermissiveMode: e.config.RosenpassPermissive,
|
||||
},
|
||||
ICEConfig: e.createICEConfig(),
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
@@ -1843,16 +1830,11 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
forwarderPort uint16,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
}
|
||||
|
||||
if forwarderPort > 0 {
|
||||
dnsfwd.SetListenPort(forwarderPort)
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1864,20 +1846,16 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
case e.dnsFwdPort != forwarderPort:
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||
e.dnsFwdPort = forwarderPort
|
||||
|
||||
default:
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
@@ -1887,20 +1865,6 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
// stop and start the forwarder to apply the new port
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for non-WASM environments
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
return icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for WASM environment.
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
cfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
@@ -27,10 +27,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -46,8 +42,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -105,10 +103,6 @@ type MockWGIface struct {
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -1590,7 +1584,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ type wgIfaceBase interface {
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
|
||||
@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
// No-op for WASM - network changes don't apply
|
||||
return nil
|
||||
}
|
||||
@@ -28,6 +28,10 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
StatusRecorder *Status
|
||||
Signaler *Signaler
|
||||
@@ -113,8 +117,6 @@ type Conn struct {
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
|
||||
endpointUpdater *EndpointUpdater
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -127,18 +129,17 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -171,9 +172,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||
@@ -248,7 +249,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
@@ -374,19 +375,12 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -415,8 +409,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
@@ -425,14 +418,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
@@ -488,8 +477,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
@@ -526,9 +514,6 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
conn.currentConnPriority = conntype.None
|
||||
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -560,6 +545,17 @@ func (conn *Conn) onGuardEvent() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
||||
presharedKey := conn.presharedKey(remoteRPKey)
|
||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
||||
conn.config.WgConfig.RemoteKey,
|
||||
conn.config.WgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
addr,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -702,6 +698,10 @@ func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
|
||||
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOfferChan := make(chan struct{})
|
||||
onNewOffeChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteOffer(OfferAnswer{
|
||||
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOfferChan:
|
||||
case <-onNewOffeChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOfferChan := make(chan struct{})
|
||||
onNewOffeChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteAnswer(OfferAnswer{
|
||||
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOfferChan:
|
||||
case <-onNewOffeChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
fallbackDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
type EndpointUpdater struct {
|
||||
log *logrus.Entry
|
||||
wgConfig WgConfig
|
||||
initiator bool
|
||||
|
||||
// mu protects updateWireGuardPeer and cancelFunc
|
||||
mu sync.Mutex
|
||||
cancelFunc func()
|
||||
updateWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
|
||||
return &EndpointUpdater{
|
||||
log: log,
|
||||
wgConfig: wgConfig,
|
||||
initiator: initiator,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.cancelFunc()
|
||||
e.cancelFunc = nil
|
||||
e.updateWg.Wait()
|
||||
}
|
||||
|
||||
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
|
||||
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
|
||||
defer e.updateWg.Done()
|
||||
t := time.NewTimer(fallbackDelay)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
e.mu.Lock()
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
return e.wgConfig.WgInterface.UpdatePeer(
|
||||
e.wgConfig.RemoteKey,
|
||||
e.wgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
endpoint,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
|
||||
)
|
||||
|
||||
func GetICEMonitorPeriod() time.Duration {
|
||||
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
|
||||
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
return defaultCandidatesMonitorPeriod
|
||||
}
|
||||
@@ -3,8 +3,6 @@ package guard
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +14,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCandidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ICEMonitor struct {
|
||||
@@ -25,19 +23,16 @@ type ICEMonitor struct {
|
||||
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig icemaker.Config
|
||||
tickerPeriod time.Duration
|
||||
|
||||
currentCandidatesAddress []string
|
||||
candidatesMu sync.Mutex
|
||||
currentCandidates []ice.Candidate
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
|
||||
log.Debugf("prepare ICE monitor with period: %s", period)
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
||||
cm := &ICEMonitor{
|
||||
ReconnectCh: make(chan struct{}, 1),
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
iceConfig: config,
|
||||
tickerPeriod: period,
|
||||
}
|
||||
return cm
|
||||
}
|
||||
@@ -49,12 +44,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
return
|
||||
}
|
||||
|
||||
// Initial check to populate the candidates for later comparison
|
||||
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
|
||||
log.Warnf("Failed to check initial ICE candidates: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cm.tickerPeriod)
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@@ -125,21 +115,16 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
||||
cm.candidatesMu.Lock()
|
||||
defer cm.candidatesMu.Unlock()
|
||||
|
||||
newAddresses := make([]string, len(newCandidates))
|
||||
for i, c := range newCandidates {
|
||||
newAddresses[i] = c.Address()
|
||||
}
|
||||
sort.Strings(newAddresses)
|
||||
|
||||
if len(cm.currentCandidatesAddress) != len(newAddresses) {
|
||||
cm.currentCandidatesAddress = newAddresses
|
||||
if len(cm.currentCandidates) != len(newCandidates) {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
|
||||
// Compare elements
|
||||
if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
|
||||
cm.currentCandidatesAddress = newAddresses
|
||||
return true
|
||||
for i, candidate := range cm.currentCandidates {
|
||||
if candidate.Address() != newCandidates[i].Address() {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
@@ -44,19 +44,13 @@ type OfferAnswer struct {
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||
// and it will only keep the latest message if multiple offers are received in a short time
|
||||
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||
// and also to avoid processing old offers if multiple offers are received in a short time
|
||||
// the listener will always process the latest offer
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
onNewOfferListeners []*OfferListener
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
@@ -76,39 +70,28 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.relayListener = NewAsyncOfferListener(offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.iceListener = offer
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
l := NewOfferListener(offer)
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
if err := h.sendAnswer(); err != nil {
|
||||
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
||||
continue
|
||||
}
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
h.log.Infof("stop listening for remote offers and answers")
|
||||
|
||||
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
|
||||
return oa.SessionID.String()
|
||||
}
|
||||
|
||||
type AsyncOfferListener struct {
|
||||
type OfferListener struct {
|
||||
fn callbackFunc
|
||||
running bool
|
||||
latest *OfferAnswer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
|
||||
return &AsyncOfferListener{
|
||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
||||
return &OfferListener{
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
|
||||
runChan <- struct{}{}
|
||||
}
|
||||
|
||||
hl := NewAsyncOfferListener(longRunningFn)
|
||||
hl := NewOfferListener(longRunningFn)
|
||||
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
|
||||
@@ -18,5 +18,4 @@ type WGIface interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
Address() wgaddr.Address
|
||||
RemoveEndpointAddress(key string) error
|
||||
}
|
||||
|
||||
@@ -92,16 +92,23 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
||||
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
if w.agent != nil || w.agentConnecting {
|
||||
if w.agentConnecting {
|
||||
w.log.Debugf("agent connection is in progress, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if w.agent != nil {
|
||||
// backward compatibility with old clients that do not send session ID
|
||||
if remoteOfferAnswer.SessionID == nil {
|
||||
w.log.Debugf("agent already exists, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
||||
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
@@ -109,12 +116,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
@@ -125,23 +126,18 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
|
||||
}
|
||||
w.log.Debugf("recreate ICE agent")
|
||||
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
||||
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.agent = agent
|
||||
w.agentDialerCancel = dialerCancel
|
||||
w.agentConnecting = true
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
} else {
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
||||
}
|
||||
@@ -297,6 +293,9 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.muxAgent.Lock()
|
||||
w.agentConnecting = false
|
||||
w.lastSuccess = time.Now()
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
// todo: the potential problem is a race between the onConnectionStateChange
|
||||
@@ -310,17 +309,16 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
// todo review does it make sense to generate new session ID all the time when w.agent==agent
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
@@ -397,12 +395,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const dnsTimeout = 8 * time.Second
|
||||
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -108,10 +108,6 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
|
||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||
}
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
@@ -212,7 +208,7 @@ func (m *DefaultManager) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@@ -223,7 +219,7 @@ func (m *DefaultManager) Init() error {
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
@@ -289,15 +285,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
nbnet.SetVPNInterfaceName("")
|
||||
}
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const localSubnetsCacheTTL = 15 * time.Minute
|
||||
@@ -95,9 +96,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hooks.RemoveWriteHooks()
|
||||
hooks.RemoveCloseHooks()
|
||||
hooks.RemoveAddressRemoveHooks()
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
@@ -289,7 +290,12 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
@@ -298,7 +304,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID hooks.ConnectionID) error {
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
@@ -311,20 +317,36 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||
continue
|
||||
}
|
||||
if err := beforeHook("init", prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||
}
|
||||
}
|
||||
|
||||
hooks.AddWriteHook(beforeHook)
|
||||
hooks.AddCloseHook(afterHook)
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
var merr *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
@@ -144,11 +143,10 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||
@@ -343,11 +341,10 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
@@ -487,11 +484,10 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
var ErrRouteNotSupported = errors.New("route operations not supported on js")
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||
return []DetailedRoute{}, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// IPRule contains IP rule information for debugging
|
||||
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
|
||||
if !advancedRouting {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if !advancedRouting {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user