mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 09:03:54 -04:00
Compare commits
37 Commits
snyk-fix-9
...
feature/me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee8739760d | ||
|
|
63b003f255 | ||
|
|
000e99e7f3 | ||
|
|
0d2e67983a | ||
|
|
5151f19d29 | ||
|
|
bedd3cabc9 | ||
|
|
d35a845dbd | ||
|
|
4e03f708a4 | ||
|
|
654aa9581d | ||
|
|
9021bb512b | ||
|
|
768332820e | ||
|
|
229c65ffa1 | ||
|
|
4d33567888 | ||
|
|
88467883fc | ||
|
|
954f40991f | ||
|
|
34341d95a9 | ||
|
|
e7b5537dcc | ||
|
|
95794f53ce | ||
|
|
9bcd3ebed4 | ||
|
|
b85045e723 | ||
|
|
4d7e59f199 | ||
|
|
b5daec3b51 | ||
|
|
5e1a40c33f | ||
|
|
e8d301fdc9 | ||
|
|
17bab881f7 | ||
|
|
25ed58328a | ||
|
|
644ed4b934 | ||
|
|
58faa341d2 | ||
|
|
5853b5553c | ||
|
|
998fb30e1e | ||
|
|
e254b4cde5 | ||
|
|
ead1c618ba | ||
|
|
55126f990c | ||
|
|
90577682e4 | ||
|
|
dc30dcacce | ||
|
|
2c87fa6236 | ||
|
|
ec8d83ade4 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.22"
|
SIGN_PIPE_VER: "v0.0.23"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
name: Wasm
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
js_lint:
|
||||||
|
name: "JS / Lint"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
|
- name: Install golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
install-mode: binary
|
||||||
|
skip-cache: true
|
||||||
|
skip-pkg-cache: true
|
||||||
|
skip-build-cache: true
|
||||||
|
- name: Run golangci-lint for WASM
|
||||||
|
run: |
|
||||||
|
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
js_build:
|
||||||
|
name: "JS / Build"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
- name: Build Wasm client
|
||||||
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
|
env:
|
||||||
|
CGO_ENABLED: 0
|
||||||
|
- name: Check Wasm build size
|
||||||
|
run: |
|
||||||
|
echo "Wasm build size:"
|
||||||
|
ls -lh netbird.wasm
|
||||||
|
|
||||||
|
SIZE=$(stat -c%s netbird.wasm)
|
||||||
|
SIZE_MB=$((SIZE / 1024 / 1024))
|
||||||
|
|
||||||
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
|
if [ ${SIZE} -gt 52428800 ]; then
|
||||||
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
0
.gitmodules
vendored
Normal file
0
.gitmodules
vendored
Normal file
@@ -2,6 +2,18 @@ version: 2
|
|||||||
|
|
||||||
project_name: netbird
|
project_name: netbird
|
||||||
builds:
|
builds:
|
||||||
|
- id: netbird-wasm
|
||||||
|
dir: client/wasm/cmd
|
||||||
|
binary: netbird
|
||||||
|
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||||
|
goos:
|
||||||
|
- js
|
||||||
|
goarch:
|
||||||
|
- wasm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird
|
- id: netbird
|
||||||
dir: client
|
dir: client
|
||||||
binary: netbird
|
binary: netbird
|
||||||
@@ -115,6 +127,11 @@ archives:
|
|||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
- netbird-static
|
- netbird-static
|
||||||
|
- id: netbird-wasm
|
||||||
|
builds:
|
||||||
|
- netbird-wasm
|
||||||
|
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||||
|
format: binary
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<br/>
|
<br/>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -52,7 +53,7 @@
|
|||||||
|
|
||||||
### Open Source Network Security in a Single Platform
|
### Open Source Network Security in a Single Platform
|
||||||
|
|
||||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### NetBird on Lawrence Systems (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ ENV \
|
|||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type ErrListener interface {
|
|||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(string)
|
||||||
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
}
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
8
client/cmd/debug_js.go
Normal file
8
client/cmd/debug_js.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetupDebugHandler is a no-op for WASM
|
||||||
|
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||||
|
// Debug handler not needed for WASM
|
||||||
|
}
|
||||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
|||||||
|
|
||||||
// DialClientGRPCServer returns client connection to the daemon server.
|
// DialClientGRPCServer returns client connection to the daemon server.
|
||||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return grpc.DialContext(
|
return grpc.DialContext(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||||
|
WaitForReady: func() *bool { b := true; return &b }(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,23 +23,29 @@ import (
|
|||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
var ErrClientNotStarted = errors.New("client not started")
|
||||||
|
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
setupKey string
|
setupKey string
|
||||||
|
jwtToken string
|
||||||
connect *internal.ConnectClient
|
connect *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options configures a new Client
|
// Options configures a new Client.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// DeviceName is this peer's name in the network
|
// DeviceName is this peer's name in the network
|
||||||
DeviceName string
|
DeviceName string
|
||||||
// SetupKey is used for authentication
|
// SetupKey is used for authentication
|
||||||
SetupKey string
|
SetupKey string
|
||||||
|
// JWTToken is used for JWT-based authentication
|
||||||
|
JWTToken string
|
||||||
|
// PrivateKey is used for direct private key authentication
|
||||||
|
PrivateKey string
|
||||||
// ManagementURL overrides the default management server URL
|
// ManagementURL overrides the default management server URL
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
@@ -58,8 +64,35 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new netbird embedded client
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
|
func (opts *Options) validateCredentials() error {
|
||||||
|
credentialsProvided := 0
|
||||||
|
if opts.SetupKey != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
if opts.JWTToken != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
if opts.PrivateKey != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
|
||||||
|
if credentialsProvided == 0 {
|
||||||
|
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
|
||||||
|
}
|
||||||
|
if credentialsProvided > 1 {
|
||||||
|
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new netbird embedded client.
|
||||||
func New(opts Options) (*Client, error) {
|
func New(opts Options) (*Client, error) {
|
||||||
|
if err := opts.validateCredentials(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if opts.LogOutput != nil {
|
if opts.LogOutput != nil {
|
||||||
logrus.SetOutput(opts.LogOutput)
|
logrus.SetOutput(opts.LogOutput)
|
||||||
}
|
}
|
||||||
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
|
|||||||
return nil, fmt.Errorf("create config: %w", err)
|
return nil, fmt.Errorf("create config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.PrivateKey != "" {
|
||||||
|
config.PrivateKey = opts.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
run := make(chan struct{}, 1)
|
run := make(chan struct{})
|
||||||
clientErr := make(chan error, 1)
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfig returns a copy of the internal client config.
|
||||||
|
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.config == nil {
|
||||||
|
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||||
|
}
|
||||||
|
return *c.config, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
|||||||
return nsnet.ListenTCP(tcpAddr)
|
return nsnet.ListenTCP(tcpAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the given address in the netbird network
|
// ListenUDP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -4,15 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os/user"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -21,35 +15,9 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
currentUser, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
|
||||||
if currentUser.Uid != "0" {
|
|
||||||
log.Debug("Not running as root, using standard dialer")
|
|
||||||
dialer := &net.Dialer{}
|
|
||||||
return dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
b.MaxElapsedTime = 10 * time.Second
|
b.MaxElapsedTime = 10 * time.Second
|
||||||
@@ -57,9 +25,12 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
if tlsEnabled {
|
// for js, the outer websocket layer takes care of tls
|
||||||
|
if tlsEnabled && runtime.GOOS != "js" {
|
||||||
certPool, err := x509.SystemCertPool()
|
certPool, err := x509.SystemCertPool()
|
||||||
if err != nil || certPool == nil {
|
if err != nil || certPool == nil {
|
||||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||||
@@ -71,14 +42,14 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
conn, err := grpc.DialContext(
|
||||||
connCtx,
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
44
client/grpc/dialer_generic.go
Normal file
44
client/grpc/dialer_generic.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||||
|
if currentUser.Uid != "0" {
|
||||||
|
log.Debug("Not running as root, using standard dialer")
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to dial: %s", err)
|
||||||
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
13
client/grpc/dialer_js.go
Normal file
13
client/grpc/dialer_js.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||||
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
|
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||||
|
return client.WithWebSocketDialer(tlsEnabled, component)
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package bind
|
|||||||
import (
|
import (
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||||
|
|||||||
@@ -1,5 +1,17 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
type Endpoint = wgConn.StdNetEndpoint
|
type Endpoint = wgConn.StdNetEndpoint
|
||||||
|
|
||||||
|
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: e.Addr().AsSlice(),
|
||||||
|
Port: int(e.Port()),
|
||||||
|
Zone: e.Addr().Zone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
7
client/iface/bind/error.go
Normal file
7
client/iface/bind/error.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
|
||||||
|
)
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -17,14 +20,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
|
||||||
Endpoint *Endpoint
|
|
||||||
Buffer []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type receiverCreator struct {
|
type receiverCreator struct {
|
||||||
iceBind *ICEBind
|
iceBind *ICEBind
|
||||||
}
|
}
|
||||||
@@ -42,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
|
|||||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||||
type ICEBind struct {
|
type ICEBind struct {
|
||||||
*wgConn.StdNetBind
|
*wgConn.StdNetBind
|
||||||
RecvChan chan RecvMessage
|
|
||||||
|
|
||||||
transportNet transport.Net
|
transportNet transport.Net
|
||||||
filterFn udpmux.FilterFn
|
filterFn udpmux.FilterFn
|
||||||
endpoints map[netip.Addr]net.Conn
|
address wgaddr.Address
|
||||||
endpointsMu sync.Mutex
|
mtu uint16
|
||||||
|
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
endpointsMu sync.Mutex
|
||||||
|
recvChan chan recvMessage
|
||||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||||
closedChan chan struct{}
|
closedChan chan struct{}
|
||||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
|
||||||
address wgaddr.Address
|
|
||||||
mtu uint16
|
|
||||||
activityRecorder *ActivityRecorder
|
activityRecorder *ActivityRecorder
|
||||||
|
|
||||||
|
muUDPMux sync.Mutex
|
||||||
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
RecvChan: make(chan RecvMessage, 1),
|
|
||||||
transportNet: transportNet,
|
transportNet: transportNet,
|
||||||
filterFn: filterFn,
|
filterFn: filterFn,
|
||||||
|
address: address,
|
||||||
|
mtu: mtu,
|
||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
recvChan: make(chan recvMessage, 1),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
mtu: mtu,
|
|
||||||
address: address,
|
|
||||||
activityRecorder: NewActivityRecorder(),
|
activityRecorder: NewActivityRecorder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
|
|||||||
return ib
|
return ib
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) MTU() uint16 {
|
|
||||||
return s.mtu
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
s.closed = false
|
s.closed = false
|
||||||
s.closedChanMu.Lock()
|
s.closedChanMu.Lock()
|
||||||
@@ -139,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
|||||||
delete(b.endpoints, fakeIP)
|
delete(b.endpoints, fakeIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||||
|
select {
|
||||||
|
case <-b.closedChan:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case b.recvChan <- recvMessage{ep, buf}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
conn, ok := b.endpoints[ep.DstIP()]
|
conn, ok := b.endpoints[ep.DstIP()]
|
||||||
@@ -271,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
select {
|
select {
|
||||||
case <-c.closedChan:
|
case <-c.closedChan:
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
case msg, ok := <-c.RecvChan:
|
case msg, ok := <-c.recvChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|||||||
6
client/iface/bind/recv_msg.go
Normal file
6
client/iface/bind/recv_msg.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
type recvMessage struct {
|
||||||
|
Endpoint *Endpoint
|
||||||
|
Buffer []byte
|
||||||
|
}
|
||||||
125
client/iface/bind/relay_bind.go
Normal file
125
client/iface/bind/relay_bind.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
|
||||||
|
// Do not limit to build only js, because we want to be able to run tests
|
||||||
|
type RelayBindJS struct {
|
||||||
|
*conn.StdNetBind
|
||||||
|
|
||||||
|
recvChan chan recvMessage
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
endpointsMu sync.Mutex
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayBindJS() *RelayBindJS {
|
||||||
|
return &RelayBindJS{
|
||||||
|
recvChan: make(chan recvMessage, 100),
|
||||||
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
activityRecorder: NewActivityRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open creates a receive function for handling relay packets in WASM.
|
||||||
|
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||||
|
log.Debugf("Open: creating receive function for port %d", uport)
|
||||||
|
|
||||||
|
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
case msg, ok := <-s.recvChan:
|
||||||
|
if !ok {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
copy(bufs[0], msg.Buffer)
|
||||||
|
sizes[0] = len(msg.Buffer)
|
||||||
|
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||||
|
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) Close() error {
|
||||||
|
if s.cancel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("close RelayBindJS")
|
||||||
|
s.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case s.recvChan <- recvMessage{ep, buf}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send forwards packets through the relay connection for WASM.
|
||||||
|
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
|
if ep == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP := ep.DstIP()
|
||||||
|
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
relayConn, ok := s.endpoints[fakeIP]
|
||||||
|
s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, buf := range bufs {
|
||||||
|
if _, err := relayConn.Write(buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
|
b.endpointsMu.Lock()
|
||||||
|
b.endpoints[fakeIP] = conn
|
||||||
|
b.endpointsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
defer s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
delete(s.endpoints, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
|
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
|
return nil, ErrUDPMUXNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
|
||||||
|
return s.activityRecorder
|
||||||
|
}
|
||||||
@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the existing peer to preserve its allowed IPs
|
||||||
|
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
removePeerCfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
|
||||||
|
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Re-add the peer without the endpoint but same AllowedIPs
|
||||||
|
reAddPeerCfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
AllowedIPs: existingPeer.AllowedIPs,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
|
||||||
|
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux || windows || freebsd
|
//go:build linux || windows || freebsd || js || wasip1
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !windows
|
//go:build !windows && !js
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
23
client/iface/configurer/uapi_js.go
Normal file
23
client/iface/configurer/uapi_js.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopListener struct{}
|
||||||
|
|
||||||
|
func (n *noopListener) Accept() (net.Conn, error) {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Addr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
|
return &noopListener{}, nil
|
||||||
|
}
|
||||||
@@ -17,8 +17,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse peer key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipcStr, err := c.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get IPC config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse current status to get allowed IPs for the peer
|
||||||
|
stats, err := parseStatus(c.deviceName, ipcStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse IPC config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allowedIPs []net.IPNet
|
||||||
|
found := false
|
||||||
|
for _, peer := range stats.Peers {
|
||||||
|
if peer.PublicKey == peerKey {
|
||||||
|
allowedIPs = peer.AllowedIPs
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("peer %s not found", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the peer from the WireGuard configuration
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||||
|
return fmt.Errorf("failed to remove peer: %s", ipcErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the peer config
|
||||||
|
peer = wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: allowedIPs,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
|
||||||
|
return fmt.Errorf("remove endpoint address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -409,7 +470,7 @@ func toBytes(s string) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/sharedsock"
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunKernelDevice struct {
|
type TunKernelDevice struct {
|
||||||
@@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var udpConn net.PacketConn = rawSock
|
|
||||||
if !nbnet.AdvancedRouting() {
|
|
||||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
|
||||||
}
|
|
||||||
|
|
||||||
bindParams := udpmux.UniversalUDPMuxParams{
|
bindParams := udpmux.UniversalUDPMuxParams{
|
||||||
UDPConn: udpConn,
|
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||||
Net: t.transportNet,
|
Net: t.transportNet,
|
||||||
FilterFn: t.filterFn,
|
FilterFn: t.filterFn,
|
||||||
WGAddress: t.address,
|
WGAddress: t.address,
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
@@ -12,9 +14,15 @@ import (
|
|||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Bind interface {
|
||||||
|
conn.Bind
|
||||||
|
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
|
ActivityRecorder() *bind.ActivityRecorder
|
||||||
|
}
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
name string
|
name string
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
@@ -22,7 +30,7 @@ type TunNetstackDevice struct {
|
|||||||
key string
|
key string
|
||||||
mtu uint16
|
mtu uint16
|
||||||
listenAddress string
|
listenAddress string
|
||||||
iceBind *bind.ICEBind
|
bind Bind
|
||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
@@ -33,7 +41,7 @@ type TunNetstackDevice struct {
|
|||||||
net *netstack.Net
|
net *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
|
||||||
return &TunNetstackDevice{
|
return &TunNetstackDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
|
|||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
iceBind: iceBind,
|
bind: bind,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
|
|
||||||
t.device = device.NewDevice(
|
t.device = device.NewDevice(
|
||||||
t.filteredDevice,
|
t.filteredDevice,
|
||||||
t.iceBind,
|
t.bind,
|
||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
@@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
udpMux, err := t.iceBind.GetICEMux()
|
udpMux, err := t.bind.GetICEMux()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t.udpMux = udpMux
|
|
||||||
|
if udpMux != nil {
|
||||||
|
t.udpMux = udpMux
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack device is ready to use")
|
log.Debugf("netstack device is ready to use")
|
||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
client/iface/device/device_netstack_test.go
Normal file
27
client/iface/device/device_netstack_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewNetstackDevice(t *testing.T) {
|
||||||
|
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
|
||||||
|
|
||||||
|
relayBind := bind.NewRelayBindJS()
|
||||||
|
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
|
||||||
|
|
||||||
|
cfgr, err := nsTun.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create netstack device: %v", err)
|
||||||
|
}
|
||||||
|
if cfgr == nil {
|
||||||
|
t.Fatal("expected non-nil configurer")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,4 +21,5 @@ type WGConfigurer interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
|
RemoveEndpointAddress(peerKey string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
|||||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
if w.configurer == nil {
|
||||||
|
return ErrIfaceNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Removing endpoint address: %s", peerKey)
|
||||||
|
return w.configurer.RemoveEndpointAddress(peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
|
|||||||
6
client/iface/iface_destroy_js.go
Normal file
6
client/iface/iface_destroy_js.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
// Destroy is a no-op on WASM
|
||||||
|
func (w *WGIface) Destroy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: tun,
|
tun: tun,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
41
client/iface/iface_new_freebsd.go
Normal file
41
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build freebsd
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.ModuleTunIsLoaded() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
client/iface/iface_new_js.go
Normal file
27
client/iface/iface_new_js.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
relayBind := bind.NewRelayBindJS()
|
||||||
|
|
||||||
|
wgIface := &WGIface{
|
||||||
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||||
|
userspaceBind: true,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
//go:build linux && !android
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: tun,
|
tun: tun,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
12
client/iface/netstack/env_js.go
Normal file
12
client/iface/netstack/env_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package netstack
|
||||||
|
|
||||||
|
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||||
|
|
||||||
|
// IsEnabled always returns true for js since it's the only mode available
|
||||||
|
func IsEnabled() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenAddr() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||||
|
|||||||
@@ -16,28 +16,38 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyBind struct {
|
type Bind interface {
|
||||||
Bind *bind.ICEBind
|
SetEndpoint(addr netip.Addr, conn net.Conn)
|
||||||
|
RemoveEndpoint(addr netip.Addr)
|
||||||
fakeNetIP *netip.AddrPort
|
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
|
||||||
wgBindEndpoint *bind.Endpoint
|
|
||||||
remoteConn net.Conn
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
closeMu sync.Mutex
|
|
||||||
closed bool
|
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
|
||||||
paused bool
|
|
||||||
isStarted bool
|
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
type ProxyBind struct {
|
||||||
|
bind Bind
|
||||||
|
|
||||||
|
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
|
||||||
|
wgRelayedEndpoint *bind.Endpoint
|
||||||
|
wgCurrentUsed *bind.Endpoint
|
||||||
|
remoteConn net.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
paused bool
|
||||||
|
pausedCond *sync.Cond
|
||||||
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
|
mtu uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
|
||||||
p := &ProxyBind{
|
p := &ProxyBind{
|
||||||
Bind: bind,
|
bind: bind,
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
mtu: mtu + bufsize.WGBufferOverhead,
|
||||||
}
|
}
|
||||||
|
|
||||||
return p
|
return p
|
||||||
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
|||||||
// AddTurnConn adds a new connection to the bind.
|
// AddTurnConn adds a new connection to the bind.
|
||||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||||
// WireGuard configuration.
|
// WireGuard configuration.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
|
||||||
|
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
|
||||||
|
// - remoteConn: The established TURN connection to the remote peer
|
||||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
fakeNetIP, err := fakeAddress(nbAddr)
|
fakeNetIP, err := fakeAddress(nbAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||||
p.fakeNetIP = fakeNetIP
|
|
||||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||||
return &net.UDPAddr{
|
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
|
||||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
|
||||||
Port: int(p.fakeNetIP.Port()),
|
|
||||||
Zone: p.fakeNetIP.Addr().Zone(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||||
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
p.wgCurrentUsed = p.wgRelayedEndpoint
|
||||||
|
|
||||||
// Start the proxy only once
|
// Start the proxy only once
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) Pause() {
|
func (p *ProxyBind) Pause() {
|
||||||
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
|
||||||
|
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||||
|
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||||
|
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||||
|
return &bind.Endpoint{AddrPort: addrPort}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) CloseConn() error {
|
func (p *ProxyBind) CloseConn() error {
|
||||||
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) close() error {
|
func (p *ProxyBind) close() error {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
p.closeMu.Lock()
|
p.closeMu.Lock()
|
||||||
defer p.closeMu.Unlock()
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
|
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
|
||||||
|
|
||||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
return rErr
|
return rErr
|
||||||
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
|
buf := make([]byte, p.mtu)
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := bind.RecvMessage{
|
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
|
||||||
Endpoint: p.wgBindEndpoint,
|
p.pausedCond.L.Unlock()
|
||||||
Buffer: buf[:n],
|
|
||||||
}
|
|
||||||
p.Bind.RecvChan <- msg
|
|
||||||
p.pausedMu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -18,15 +16,20 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
loopbackAddr = "127.0.0.1"
|
loopbackAddr = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||||
|
)
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.rawConn, err = p.prepareSenderRawSocket()
|
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -214,57 +217,17 @@ generatePort:
|
|||||||
return p.lastUsedPort, nil
|
return p.lastUsedPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||||
// Create a raw socket.
|
|
||||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
|
||||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the "lo" interface.
|
|
||||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
|
||||||
err = nbnet.SetSocketOpt(fd)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert the file descriptor to a PacketConn.
|
|
||||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
|
||||||
if file == nil {
|
|
||||||
return nil, fmt.Errorf("converting fd to file failed")
|
|
||||||
}
|
|
||||||
packetConn, err := net.FilePacketConn(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return packetConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
|
||||||
localhost := net.ParseIP("127.0.0.1")
|
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
payload := gopacket.Payload(data)
|
||||||
ipH := &layers.IPv4{
|
ipH := &layers.IPv4{
|
||||||
DstIP: localhost,
|
DstIP: localHostNetIP,
|
||||||
SrcIP: localhost,
|
SrcIP: endpointAddr.IP,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
udpH := &layers.UDP{
|
udpH := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(port),
|
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("serialize layers: %w", err)
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
}
|
}
|
||||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||||
return fmt.Errorf("write to raw conn: %w", err)
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -18,41 +18,42 @@ import (
|
|||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
type ProxyWrapper struct {
|
type ProxyWrapper struct {
|
||||||
WgeBPFProxy *WGEBPFProxy
|
wgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgEndpointAddr *net.UDPAddr
|
wgRelayedEndpointAddr *net.UDPAddr
|
||||||
|
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
paused bool
|
||||||
paused bool
|
pausedCond *sync.Cond
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
closeListener *listener.CloseListener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||||
return &ProxyWrapper{
|
return &ProxyWrapper{
|
||||||
WgeBPFProxy: WgeBPFProxy,
|
wgeBPFProxy: proxy,
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
}
|
}
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.wgEndpointAddr = addr
|
p.wgRelayedEndpointAddr = addr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
return p.wgEndpointAddr
|
return p.wgRelayedEndpointAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||||
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) Pause() {
|
func (p *ProxyWrapper) Pause() {
|
||||||
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
|
||||||
|
p.wgEndpointCurrentUsedAddr = endpoint
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
func (e *ProxyWrapper) CloseConn() error {
|
func (p *ProxyWrapper) CloseConn() error {
|
||||||
if e.cancel == nil {
|
if p.cancel == nil {
|
||||||
return fmt.Errorf("proxy not started")
|
return fmt.Errorf("proxy not started")
|
||||||
}
|
}
|
||||||
|
|
||||||
e.cancel()
|
p.cancel()
|
||||||
|
|
||||||
e.closeListener.SetCloseListener(nil)
|
p.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
p.pausedCond.L.Lock()
|
||||||
return fmt.Errorf("close remote conn: %w", err)
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
|
||||||
|
|
||||||
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||||
for {
|
for {
|
||||||
n, err := p.readFromRemote(ctx, buf)
|
n, err := p.readFromRemote(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
}
|
}
|
||||||
p.closeListener.Notify()
|
p.closeListener.Notify()
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
package wgproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KernelFactory todo: check eBPF support on FreeBSD
|
|
||||||
type KernelFactory struct {
|
|
||||||
wgPort int
|
|
||||||
mtu uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
|
||||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
|
||||||
f := &KernelFactory{
|
|
||||||
wgPort: wgPort,
|
|
||||||
mtu: mtu,
|
|
||||||
}
|
|
||||||
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *KernelFactory) GetProxy() Proxy {
|
|
||||||
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,24 +3,25 @@ package wgproxy
|
|||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
)
|
)
|
||||||
|
|
||||||
type USPFactory struct {
|
type USPFactory struct {
|
||||||
bind *bind.ICEBind
|
bind proxyBind.Bind
|
||||||
|
mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
|
||||||
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||||
f := &USPFactory{
|
f := &USPFactory{
|
||||||
bind: iceBind,
|
bind: bind,
|
||||||
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) GetProxy() Proxy {
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
return proxyBind.NewProxyBind(w.bind)
|
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
|
|||||||
@@ -11,6 +11,11 @@ type Proxy interface {
|
|||||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||||
Work() // Work start or resume the proxy
|
Work() // Work start or resume the proxy
|
||||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
|
|
||||||
|
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||||
|
//and rewrite the src address to the endpoint address.
|
||||||
|
//With this logic can avoid the package loss from relayed connections.
|
||||||
|
RedirectAs(endpoint *net.UDPAddr)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
SetDisconnectListener(disconnected func())
|
SetDisconnectListener(disconnected func())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,54 +3,82 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"fmt"
|
||||||
"os"
|
"net"
|
||||||
"testing"
|
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
func seedProxies() ([]proxyInstance, error) {
|
||||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
pl := make([]proxyInstance, 0)
|
||||||
t.Skip("Skipping test as it requires root privileges")
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
pEbpf := proxyInstance{
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
name: "ebpf kernel proxy",
|
||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||||
}
|
wgPort: 51831,
|
||||||
}()
|
closeFn: ebpfProxy.Free,
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
proxy Proxy
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ebpf proxy",
|
|
||||||
proxy: &ebpf.ProxyWrapper{
|
|
||||||
WgeBPFProxy: ebpfProxy,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
pl = append(pl, pEbpf)
|
||||||
|
|
||||||
for _, tt := range tests {
|
pUDP := proxyInstance{
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
name: "udp kernel proxy",
|
||||||
relayedConn := newMockConn()
|
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
wgPort: 51832,
|
||||||
if err != nil {
|
closeFn: func() error { return nil },
|
||||||
t.Errorf("error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = relayedConn.Close()
|
|
||||||
if err := tt.proxy.CloseConn(); err != nil {
|
|
||||||
t.Errorf("error: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
pl = append(pl, pUDP)
|
||||||
|
return pl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||||
|
pl := make([]proxyInstance, 0)
|
||||||
|
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pEbpf := proxyInstance{
|
||||||
|
name: "ebpf kernel proxy",
|
||||||
|
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||||
|
wgPort: 51831,
|
||||||
|
closeFn: ebpfProxy.Free,
|
||||||
|
}
|
||||||
|
pl = append(pl, pEbpf)
|
||||||
|
|
||||||
|
pUDP := proxyInstance{
|
||||||
|
name: "udp kernel proxy",
|
||||||
|
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||||
|
wgPort: 51832,
|
||||||
|
closeFn: func() error { return nil },
|
||||||
|
}
|
||||||
|
pl = append(pl, pUDP)
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||||
|
endpointAddress := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(10, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
pBind := proxyInstance{
|
||||||
|
name: "bind proxy",
|
||||||
|
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||||
|
endpointAddr: endpointAddress,
|
||||||
|
closeFn: func() error { return nil },
|
||||||
|
}
|
||||||
|
pl = append(pl, pBind)
|
||||||
|
|
||||||
|
return pl, nil
|
||||||
}
|
}
|
||||||
|
|||||||
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
|
)
|
||||||
|
|
||||||
|
func seedProxies() ([]proxyInstance, error) {
|
||||||
|
// todo extend with Bind proxy
|
||||||
|
pl := make([]proxyInstance, 0)
|
||||||
|
return pl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||||
|
pl := make([]proxyInstance, 0)
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||||
|
endpointAddress := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(10, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
pBind := proxyInstance{
|
||||||
|
name: "bind proxy",
|
||||||
|
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||||
|
endpointAddr: endpointAddress,
|
||||||
|
closeFn: func() error { return nil },
|
||||||
|
}
|
||||||
|
pl = append(pl, pBind)
|
||||||
|
return pl, nil
|
||||||
|
}
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,12 +5,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
|
||||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type proxyInstance struct {
|
||||||
|
name string
|
||||||
|
proxy Proxy
|
||||||
|
wgPort int
|
||||||
|
endpointAddr *net.UDPAddr
|
||||||
|
closeFn func() error
|
||||||
|
}
|
||||||
|
|
||||||
type mocConn struct {
|
type mocConn struct {
|
||||||
closeChan chan struct{}
|
closeChan chan struct{}
|
||||||
closed bool
|
closed bool
|
||||||
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
|||||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
tests := []struct {
|
tests, err := seedProxyForProxyCloseByRemoteConn()
|
||||||
name string
|
if err != nil {
|
||||||
proxy Proxy
|
t.Fatalf("error: %v", err)
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "userspace proxy",
|
|
||||||
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
defer func() {
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
_ = relayedConn.Close()
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
}()
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
tests = append(tests, struct {
|
|
||||||
name string
|
|
||||||
proxy Proxy
|
|
||||||
}{
|
|
||||||
name: "ebpf proxy",
|
|
||||||
proxy: proxyWrapper,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProxyRedirect todo extend the proxies with Bind proxy
|
||||||
|
func TestProxyRedirect(t *testing.T) {
|
||||||
|
tests, err := seedProxies()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
|
||||||
|
if err := tt.closeFn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
msgHelloFromRelay := []byte("hello from relay")
|
||||||
|
msgRedirected := [][]byte{
|
||||||
|
[]byte("hello 1. to p2p"),
|
||||||
|
[]byte("hello 2. to p2p"),
|
||||||
|
[]byte("hello 3. to p2p"),
|
||||||
|
}
|
||||||
|
|
||||||
|
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: wgPort})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to listen on udp port: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayedServer, _ := net.ListenUDP("udp",
|
||||||
|
&net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = dummyWgListener.Close()
|
||||||
|
_ = relayedConn.Close()
|
||||||
|
_ = relayedServer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
|
||||||
|
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dummyWgListener.Read(make([]byte, 1024))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(msgHelloFromRelay) {
|
||||||
|
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpointAddr := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(192, 168, 0, 56),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
proxy.RedirectAs(p2pEndpointAddr)
|
||||||
|
|
||||||
|
for _, msg := range msgRedirected {
|
||||||
|
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(msgRedirected); i++ {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, rAddr, err := dummyWgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rAddr.String() != p2pEndpointAddr.String() {
|
||||||
|
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
|
||||||
|
}
|
||||||
|
if string(buf[:n]) != string(msgRedirected[i]) {
|
||||||
|
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package rawsocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||||
|
// Create a raw socket.
|
||||||
|
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||||
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the socket to the "lo" interface.
|
||||||
|
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the fwmark on the socket.
|
||||||
|
err = nbnet.SetSocketOpt(fd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the file descriptor to a PacketConn.
|
||||||
|
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||||
|
if file == nil {
|
||||||
|
return nil, fmt.Errorf("converting fd to file failed")
|
||||||
|
}
|
||||||
|
packetConn, err := net.FilePacketConn(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return packetConn, nil
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
|
|||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
mtu uint16
|
mtu uint16
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
ctx context.Context
|
srcFakerConn *SrcFaker
|
||||||
cancel context.CancelFunc
|
sendPkg func(data []byte) (int, error)
|
||||||
closeMu sync.Mutex
|
ctx context.Context
|
||||||
closed bool
|
cancel context.CancelFunc
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
paused bool
|
||||||
paused bool
|
pausedCond *sync.Cond
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
closeListener *listener.CloseListener
|
||||||
}
|
}
|
||||||
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
|||||||
p := &WGUDPProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
|||||||
|
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.localConn = localConn
|
p.localConn = localConn
|
||||||
|
p.sendPkg = p.localConn.Write
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
p.sendPkg = p.localConn.Write
|
||||||
|
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close src faker conn: %s", err)
|
||||||
|
}
|
||||||
|
p.srcFakerConn = nil
|
||||||
|
}
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToRemote(p.ctx)
|
go p.proxyToRemote(p.ctx)
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pause pauses the proxy from receiving data from the remote peer
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||||
|
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
p.paused = false
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close src faker conn: %s", err)
|
||||||
|
}
|
||||||
|
p.srcFakerConn = nil
|
||||||
|
}
|
||||||
|
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create src faker conn: %s", err)
|
||||||
|
// fallback to continue without redirecting
|
||||||
|
p.paused = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.srcFakerConn = srcFakerConn
|
||||||
|
p.sendPkg = p.srcFakerConn.SendPkg
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGUDPProxy) close() error {
|
func (p *WGUDPProxy) close() error {
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
p.closeMu.Lock()
|
p.closeMu.Lock()
|
||||||
defer p.closeMu.Unlock()
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
}
|
}
|
||||||
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
|
|||||||
if err := p.localConn.Close(); err != nil {
|
if err := p.localConn.Close(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return cerrors.FormatErrorOrNil(result)
|
return cerrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
_, err = p.sendPkg(buf[:n])
|
||||||
_, err = p.localConn.Write(buf[:n])
|
p.pausedCond.L.Unlock()
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
|
|||||||
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serializeOpts = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
localHostNetIPAddr = &net.IPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type SrcFaker struct {
|
||||||
|
srcAddr *net.UDPAddr
|
||||||
|
|
||||||
|
rawSocket net.PacketConn
|
||||||
|
ipH gopacket.SerializableLayer
|
||||||
|
udpH gopacket.SerializableLayer
|
||||||
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||||
|
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f := &SrcFaker{
|
||||||
|
srcAddr: srcAddr,
|
||||||
|
rawSocket: rawSocket,
|
||||||
|
ipH: ipH,
|
||||||
|
udpH: udpH,
|
||||||
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *SrcFaker) Close() error {
|
||||||
|
return f.rawSocket.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||||
|
defer func() {
|
||||||
|
if err := f.layerBuffer.Clear(); err != nil {
|
||||||
|
log.Errorf("failed to clear layer buffer: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
|
||||||
|
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||||
|
}
|
||||||
|
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||||
|
ipH := &layers.IPv4{
|
||||||
|
DstIP: net.ParseIP("127.0.0.1"),
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
udpH := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||||
|
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||||
|
}
|
||||||
|
|
||||||
|
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipH, udpH, nil
|
||||||
|
}
|
||||||
@@ -34,7 +34,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
for i, domain := range domains {
|
for i, domain := range domains {
|
||||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||||
if r.gpo {
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
singleDomain := []string{domain}
|
singleDomain := []string{domain}
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.gpo {
|
||||||
|
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||||
|
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||||
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
|||||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||||
}
|
}
|
||||||
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
|||||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||||
}
|
}
|
||||||
|
|||||||
5
client/internal/dns/server_js.go
Normal file
5
client/internal/dns/server_js.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||||
|
return &noopHostConfigurator{}, nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceViaMemory struct {
|
type ServiceViaMemory struct {
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||||
|
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||||
systemdDbusResolvConfModeForeign = "foreign"
|
systemdDbusResolvConfModeForeign = "foreign"
|
||||||
|
|
||||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||||
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||||
|
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||||
|
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
|
|||||||
19
client/internal/dns/unclean_shutdown_js.go
Normal file
19
client/internal/dns/unclean_shutdown_js.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ShutdownState struct{}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
|
|||||||
78
client/internal/dnsfwd/cache.go
Normal file
78
client/internal/dnsfwd/cache.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
records map[string]*cacheEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheEntry struct {
|
||||||
|
ip4Addrs []netip.Addr
|
||||||
|
ip6Addrs []netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCache() *cache {
|
||||||
|
return &cache{
|
||||||
|
records: make(map[string]*cacheEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
entry, exists := c.records[normalizeDomain(domain)]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reqType {
|
||||||
|
case dns.TypeA:
|
||||||
|
return slices.Clone(entry.ip4Addrs), true
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
return slices.Clone(entry.ip6Addrs), true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
norm := normalizeDomain(domain)
|
||||||
|
entry, exists := c.records[norm]
|
||||||
|
if !exists {
|
||||||
|
entry = &cacheEntry{}
|
||||||
|
c.records[norm] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reqType {
|
||||||
|
case dns.TypeA:
|
||||||
|
entry.ip4Addrs = slices.Clone(addrs)
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
entry.ip6Addrs = slices.Clone(addrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unset removes cached entries for the given domain and request type.
|
||||||
|
func (c *cache) unset(domain string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
delete(c.records, normalizeDomain(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||||
|
// lowercase and fully-qualified (with trailing dot).
|
||||||
|
func normalizeDomain(domain string) string {
|
||||||
|
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||||
|
return dns.Fqdn(strings.ToLower(domain))
|
||||||
|
}
|
||||||
86
client/internal/dnsfwd/cache_test.go
Normal file
86
client/internal/dnsfwd/cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||||
|
t.Helper()
|
||||||
|
a, err := netip.ParseAddr(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse addr %s: %v", s, err)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheNormalization(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
|
||||||
|
// Mixed case, without trailing dot
|
||||||
|
domainInput := "ExAmPlE.CoM"
|
||||||
|
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||||
|
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||||
|
|
||||||
|
// Lookup with lower, with trailing dot
|
||||||
|
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||||
|
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup with different casing again
|
||||||
|
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||||
|
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheSeparateTypes(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
|
||||||
|
domain := "test.local"
|
||||||
|
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||||
|
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||||
|
|
||||||
|
c.set(domain, 1 /* A */, ipv4)
|
||||||
|
c.set(domain, 28 /* AAAA */, ipv6)
|
||||||
|
|
||||||
|
got4, ok4 := c.get(domain, 1)
|
||||||
|
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||||
|
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||||
|
}
|
||||||
|
|
||||||
|
got6, ok6 := c.get(domain, 28)
|
||||||
|
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||||
|
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
domain := "clone.test"
|
||||||
|
|
||||||
|
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||||
|
c.set(domain, 1, src)
|
||||||
|
|
||||||
|
// Mutate source slice; cache should be unaffected
|
||||||
|
src[0] = mustAddr(t, "9.9.9.9")
|
||||||
|
|
||||||
|
got, ok := c.get(domain, 1)
|
||||||
|
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||||
|
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutate returned slice; internal cache should remain unchanged
|
||||||
|
got[0] = mustAddr(t, "4.4.4.4")
|
||||||
|
got2, ok2 := c.get(domain, 1)
|
||||||
|
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||||
|
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheMiss(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||||
|
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -46,6 +46,7 @@ type DNSForwarder struct {
|
|||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
firewall firewaller
|
firewall firewaller
|
||||||
resolver resolver
|
resolver resolver
|
||||||
|
cache *cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
|||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
resolver: net.DefaultResolver,
|
resolver: net.DefaultResolver,
|
||||||
|
cache: newCache(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
|||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
|
// remove cache entries for domains that no longer appear
|
||||||
|
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||||
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||||
|
// in the old list but not present in the new list.
|
||||||
|
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||||
|
if f.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newSet := make(map[string]struct{}, len(newEntries))
|
||||||
|
for _, e := range newEntries {
|
||||||
|
if e == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range oldEntries {
|
||||||
|
if e == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pattern := e.Domain.PunycodeString()
|
||||||
|
if _, ok := newSet[pattern]; !ok {
|
||||||
|
f.cache.unset(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
|||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(
|
||||||
|
ctx context.Context,
|
||||||
|
w dns.ResponseWriter,
|
||||||
|
question dns.Question,
|
||||||
|
resp *dns.Msg,
|
||||||
|
domain string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
// Default to SERVFAIL; override below when appropriate.
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
|
||||||
|
qType := question.Qtype
|
||||||
|
qTypeName := dns.TypeToString[qType]
|
||||||
|
|
||||||
|
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
if !errors.As(err, &dnsErr) {
|
||||||
switch {
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
case errors.As(err, &dnsErr):
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
if dnsErr.IsNotFound {
|
|
||||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
if dnsErr.IsNotFound {
|
||||||
} else {
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
}
|
}
|
||||||
default:
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upstream failed but we might have a cached answer—serve it if present.
|
||||||
|
if ips, ok := f.cache.get(domain, qType); ok {
|
||||||
|
if len(ips) > 0 {
|
||||||
|
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
|
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||||
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No cache. Log with or without the server field for more context.
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
// Write final failure response.
|
||||||
log.Errorf("failed to write failure DNS response: %v", err)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
|||||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensures that when the first query succeeds and populates the cache,
|
||||||
|
// a subsequent upstream failure still returns a successful response from cache.
|
||||||
|
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("1.2.3.4")
|
||||||
|
|
||||||
|
// First call resolves successfully and populates cache
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||||
|
Return([]netip.Addr{ip}, nil).Once()
|
||||||
|
|
||||||
|
// Second call fails upstream; forwarder should serve from cache
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||||
|
|
||||||
|
// First query: populate cache
|
||||||
|
q1 := &dns.Msg{}
|
||||||
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
|
w1 := &test.MockResponseWriter{}
|
||||||
|
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||||
|
require.NotNil(t, resp1)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
|
require.Len(t, resp1.Answer, 1)
|
||||||
|
|
||||||
|
// Second query: serve from cache after upstream failure
|
||||||
|
q2 := &dns.Msg{}
|
||||||
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(w2, q2)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "expected response to be written")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||||
|
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("ExAmPlE.CoM")
|
||||||
|
require.NoError(t, err)
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("9.8.7.6")
|
||||||
|
|
||||||
|
// Initial resolution with mixed case to populate cache
|
||||||
|
mixedQuery := "ExAmPlE.CoM"
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||||
|
Return([]netip.Addr{ip}, nil).Once()
|
||||||
|
|
||||||
|
q1 := &dns.Msg{}
|
||||||
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
|
w1 := &test.MockResponseWriter{}
|
||||||
|
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||||
|
require.NotNil(t, resp1)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
|
require.Len(t, resp1.Answer, 1)
|
||||||
|
|
||||||
|
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||||
|
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||||
|
|
||||||
|
q2 := &dns.Msg{}
|
||||||
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(w2, q2)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||||
// Test complex overlapping pattern scenarios
|
// Test complex overlapping pattern scenarios
|
||||||
mockFirewall := &MockFirewall{}
|
mockFirewall := &MockFirewall{}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -11,14 +12,18 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||||
|
listenPort uint16 = 5353
|
||||||
|
listenPortMu sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
dnsTTL = 60 //seconds
|
||||||
ListenPort = 5353
|
|
||||||
dnsTTL = 60 //seconds
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||||
@@ -37,6 +42,18 @@ type Manager struct {
|
|||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ListenPort() uint16 {
|
||||||
|
listenPortMu.RLock()
|
||||||
|
defer listenPortMu.RUnlock()
|
||||||
|
return listenPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetListenPort(port uint16) {
|
||||||
|
listenPortMu.Lock()
|
||||||
|
listenPort = port
|
||||||
|
listenPortMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
firewall: fw,
|
firewall: fw,
|
||||||
@@ -54,7 +71,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||||
go func() {
|
go func() {
|
||||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||||
// todo handle close error if it is exists
|
// todo handle close error if it is exists
|
||||||
@@ -94,7 +111,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
func (m *Manager) allowDNSFirewall() error {
|
func (m *Manager) allowDNSFirewall() error {
|
||||||
dport := &firewall.Port{
|
dport := &firewall.Port{
|
||||||
IsRange: false,
|
IsRange: false,
|
||||||
Values: []uint16{ListenPort},
|
Values: []uint16{ListenPort()},
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.firewall == nil {
|
if m.firewall == nil {
|
||||||
|
|||||||
@@ -198,6 +198,13 @@ type Engine struct {
|
|||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
|
// WireGuard interface monitor
|
||||||
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
|
wgIfaceMonitorWg sync.WaitGroup
|
||||||
|
|
||||||
|
// dns forwarder port
|
||||||
|
dnsFwdPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -240,6 +247,7 @@ func NewEngine(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
|
dnsFwdPort: dnsfwd.ListenPort(),
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
@@ -341,6 +349,9 @@ func (e *Engine) Stop() error {
|
|||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop WireGuard interface monitor and wait for it to exit
|
||||||
|
e.wgIfaceMonitorWg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,14 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iceCfg := icemaker.Config{
|
iceCfg := e.createICEConfig()
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
}
|
|
||||||
|
|
||||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
@@ -477,6 +481,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
|
|
||||||
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
|
e.wgIfaceMonitorWg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer e.wgIfaceMonitorWg.Done()
|
||||||
|
|
||||||
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
|
e.restartEngine()
|
||||||
|
} else if err != nil {
|
||||||
|
log.Warnf("WireGuard interface monitor: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1064,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||||
@@ -1322,14 +1342,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
Addr: e.getRosenpassAddr(),
|
Addr: e.getRosenpassAddr(),
|
||||||
PermissiveMode: e.config.RosenpassPermissive,
|
PermissiveMode: e.config.RosenpassPermissive,
|
||||||
},
|
},
|
||||||
ICEConfig: icemaker.Config{
|
ICEConfig: e.createICEConfig(),
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceDependencies := peer.ServiceDependencies{
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
@@ -1830,11 +1843,16 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
|||||||
func (e *Engine) updateDNSForwarder(
|
func (e *Engine) updateDNSForwarder(
|
||||||
enabled bool,
|
enabled bool,
|
||||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||||
|
forwarderPort uint16,
|
||||||
) {
|
) {
|
||||||
if e.config.DisableServerRoutes {
|
if e.config.DisableServerRoutes {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if forwarderPort > 0 {
|
||||||
|
dnsfwd.SetListenPort(forwarderPort)
|
||||||
|
}
|
||||||
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
return
|
return
|
||||||
@@ -1846,16 +1864,20 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(fwdEntries) > 0 {
|
if len(fwdEntries) > 0 {
|
||||||
if e.dnsForwardMgr == nil {
|
switch {
|
||||||
|
case e.dnsForwardMgr == nil:
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||||
|
|
||||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||||
} else {
|
case e.dnsFwdPort != forwarderPort:
|
||||||
|
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||||
|
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||||
|
e.dnsFwdPort = forwarderPort
|
||||||
|
|
||||||
|
default:
|
||||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||||
}
|
}
|
||||||
} else if e.dnsForwardMgr != nil {
|
} else if e.dnsForwardMgr != nil {
|
||||||
@@ -1865,6 +1887,20 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
}
|
}
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||||
|
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||||
|
// stop and start the forwarder to apply the new port
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||||
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||||
|
|||||||
19
client/internal/engine_generic.go
Normal file
19
client/internal/engine_generic.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for non-WASM environments
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
return icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||||
|
UDPMuxSrflx: e.udpMux,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
}
|
||||||
18
client/internal/engine_js.go
Normal file
18
client/internal/engine_js.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for WASM environment.
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
cfg := icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
@@ -27,6 +27,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -42,10 +46,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
@@ -103,6 +105,10 @@ type MockWGIface struct {
|
|||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||||
return nil, fmt.Errorf("not implemented")
|
return nil, fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
@@ -1584,7 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
|||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
RemoveEndpointAddress(key string) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/ti-mo/netfilter"
|
"github.com/ti-mo/netfilter"
|
||||||
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultChannelSize = 100
|
const defaultChannelSize = 100
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
|||||||
|
|
||||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||||
// check dns collection
|
// check dns collection
|
||||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
12
client/internal/networkmonitor/check_change_js.go
Normal file
12
client/internal/networkmonitor/check_change_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
// No-op for WASM - network changes don't apply
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -28,10 +28,6 @@ import (
|
|||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServiceDependencies struct {
|
type ServiceDependencies struct {
|
||||||
StatusRecorder *Status
|
StatusRecorder *Status
|
||||||
Signaler *Signaler
|
Signaler *Signaler
|
||||||
@@ -117,6 +113,8 @@ type Conn struct {
|
|||||||
|
|
||||||
// debug purpose
|
// debug purpose
|
||||||
dumpState *stateDump
|
dumpState *stateDump
|
||||||
|
|
||||||
|
endpointUpdater *EndpointUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
connLog := log.WithField("peer", config.Key)
|
connLog := log.WithField("peer", config.Key)
|
||||||
|
|
||||||
var conn = &Conn{
|
var conn = &Conn{
|
||||||
Log: connLog,
|
Log: connLog,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: services.StatusRecorder,
|
statusRecorder: services.StatusRecorder,
|
||||||
signaler: services.Signaler,
|
signaler: services.Signaler,
|
||||||
iFaceDiscover: services.IFaceDiscover,
|
iFaceDiscover: services.IFaceDiscover,
|
||||||
relayManager: services.RelayManager,
|
relayManager: services.RelayManager,
|
||||||
srWatcher: services.SrWatcher,
|
srWatcher: services.SrWatcher,
|
||||||
semaphore: services.Semaphore,
|
semaphore: services.Semaphore,
|
||||||
statusRelay: worker.NewAtomicStatus(),
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||||
|
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@@ -172,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||||
if !isForceRelayed() {
|
if !isForceRelayed() {
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||||
@@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
|
||||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||||
|
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||||
|
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||||
conn.handleConfigurationFailure(err, wgProxy)
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
|
||||||
|
conn.wgProxyRelay.RedirectAs(ep)
|
||||||
|
}
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
@@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||||
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -418,10 +425,14 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
defer conn.wgWatcherWg.Done()
|
defer conn.wgWatcherWg.Done()
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
}()
|
}()
|
||||||
|
conn.wgProxyRelay.Work()
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
} else {
|
} else {
|
||||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
|
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||||
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||||
@@ -477,7 +488,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||||
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
@@ -514,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
if conn.currentConnPriority == conntype.Relay {
|
if conn.currentConnPriority == conntype.Relay {
|
||||||
conn.Log.Debugf("clean up WireGuard config")
|
conn.Log.Debugf("clean up WireGuard config")
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
|
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||||
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
@@ -545,17 +560,6 @@ func (conn *Conn) onGuardEvent() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
|
||||||
presharedKey := conn.presharedKey(remoteRPKey)
|
|
||||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
|
||||||
conn.config.WgConfig.RemoteKey,
|
|
||||||
conn.config.WgConfig.AllowedIps,
|
|
||||||
defaultWgKeepAlive,
|
|
||||||
addr,
|
|
||||||
presharedKey,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -698,10 +702,6 @@ func (conn *Conn) isICEActive() bool {
|
|||||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) removeWgPeer() error {
|
|
||||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
if wgProxy != nil {
|
if wgProxy != nil {
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onNewOffeChan := make(chan struct{})
|
onNewOfferChan := make(chan struct{})
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||||
onNewOffeChan <- struct{}{}
|
onNewOfferChan <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
conn.OnRemoteOffer(OfferAnswer{
|
conn.OnRemoteOffer(OfferAnswer{
|
||||||
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-onNewOffeChan:
|
case <-onNewOfferChan:
|
||||||
// success
|
// success
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("expected to receive a new offer notification, but timed out")
|
t.Error("expected to receive a new offer notification, but timed out")
|
||||||
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onNewOffeChan := make(chan struct{})
|
onNewOfferChan := make(chan struct{})
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||||
onNewOffeChan <- struct{}{}
|
onNewOfferChan <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
conn.OnRemoteAnswer(OfferAnswer{
|
conn.OnRemoteAnswer(OfferAnswer{
|
||||||
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-onNewOffeChan:
|
case <-onNewOfferChan:
|
||||||
// success
|
// success
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("expected to receive a new offer notification, but timed out")
|
t.Error("expected to receive a new offer notification, but timed out")
|
||||||
|
|||||||
105
client/internal/peer/endpoint.go
Normal file
105
client/internal/peer/endpoint.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
|
fallbackDelay = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndpointUpdater struct {
|
||||||
|
log *logrus.Entry
|
||||||
|
wgConfig WgConfig
|
||||||
|
initiator bool
|
||||||
|
|
||||||
|
// mu protects updateWireGuardPeer and cancelFunc
|
||||||
|
mu sync.Mutex
|
||||||
|
cancelFunc func()
|
||||||
|
updateWg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
|
||||||
|
return &EndpointUpdater{
|
||||||
|
log: log,
|
||||||
|
wgConfig: wgConfig,
|
||||||
|
initiator: initiator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||||
|
// The initiator immediately configures the endpoint, while the non-initiator
|
||||||
|
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||||
|
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
if e.initiator {
|
||||||
|
e.log.Debugf("configure up WireGuard as initiatr")
|
||||||
|
return e.updateWireGuardPeer(addr, presharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prevent to run new update while cancel the previous update
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||||
|
e.updateWg.Add(1)
|
||||||
|
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||||
|
|
||||||
|
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||||
|
return e.updateWireGuardPeer(nil, presharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||||
|
if e.cancelFunc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cancelFunc()
|
||||||
|
e.cancelFunc = nil
|
||||||
|
e.updateWg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
|
||||||
|
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
|
||||||
|
defer e.updateWg.Done()
|
||||||
|
t := time.NewTimer(fallbackDelay)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
e.mu.Lock()
|
||||||
|
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||||
|
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||||
|
}
|
||||||
|
e.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
return e.wgConfig.WgInterface.UpdatePeer(
|
||||||
|
e.wgConfig.RemoteKey,
|
||||||
|
e.wgConfig.AllowedIps,
|
||||||
|
defaultWgKeepAlive,
|
||||||
|
endpoint,
|
||||||
|
presharedKey,
|
||||||
|
)
|
||||||
|
}
|
||||||
20
client/internal/peer/guard/env.go
Normal file
20
client/internal/peer/guard/env.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package guard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetICEMonitorPeriod() time.Duration {
|
||||||
|
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
|
||||||
|
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultCandidatesMonitorPeriod
|
||||||
|
}
|
||||||
@@ -3,6 +3,8 @@ package guard
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,8 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
candidatesMonitorPeriod = 5 * time.Minute
|
defaultCandidatesMonitorPeriod = 5 * time.Minute
|
||||||
candidateGatheringTimeout = 5 * time.Second
|
candidateGatheringTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type ICEMonitor struct {
|
type ICEMonitor struct {
|
||||||
@@ -23,16 +25,19 @@ type ICEMonitor struct {
|
|||||||
|
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
iceConfig icemaker.Config
|
iceConfig icemaker.Config
|
||||||
|
tickerPeriod time.Duration
|
||||||
|
|
||||||
currentCandidates []ice.Candidate
|
currentCandidatesAddress []string
|
||||||
candidatesMu sync.Mutex
|
candidatesMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
|
||||||
|
log.Debugf("prepare ICE monitor with period: %s", period)
|
||||||
cm := &ICEMonitor{
|
cm := &ICEMonitor{
|
||||||
ReconnectCh: make(chan struct{}, 1),
|
ReconnectCh: make(chan struct{}, 1),
|
||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
iceConfig: config,
|
iceConfig: config,
|
||||||
|
tickerPeriod: period,
|
||||||
}
|
}
|
||||||
return cm
|
return cm
|
||||||
}
|
}
|
||||||
@@ -44,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
// Initial check to populate the candidates for later comparison
|
||||||
|
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
|
||||||
|
log.Warnf("Failed to check initial ICE candidates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(cm.tickerPeriod)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -115,16 +125,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
|||||||
cm.candidatesMu.Lock()
|
cm.candidatesMu.Lock()
|
||||||
defer cm.candidatesMu.Unlock()
|
defer cm.candidatesMu.Unlock()
|
||||||
|
|
||||||
if len(cm.currentCandidates) != len(newCandidates) {
|
newAddresses := make([]string, len(newCandidates))
|
||||||
cm.currentCandidates = newCandidates
|
for i, c := range newCandidates {
|
||||||
|
newAddresses[i] = c.Address()
|
||||||
|
}
|
||||||
|
sort.Strings(newAddresses)
|
||||||
|
|
||||||
|
if len(cm.currentCandidatesAddress) != len(newAddresses) {
|
||||||
|
cm.currentCandidatesAddress = newAddresses
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, candidate := range cm.currentCandidates {
|
// Compare elements
|
||||||
if candidate.Address() != newCandidates[i].Address() {
|
if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
|
||||||
cm.currentCandidates = newCandidates
|
cm.currentCandidatesAddress = newAddresses
|
||||||
return true
|
return true
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
w.cancelIceMonitor = cancel
|
w.cancelIceMonitor = cancel
|
||||||
|
|
||||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||||
|
|||||||
@@ -44,13 +44,19 @@ type OfferAnswer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
ice *WorkerICE
|
ice *WorkerICE
|
||||||
relay *WorkerRelay
|
relay *WorkerRelay
|
||||||
onNewOfferListeners []*OfferListener
|
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||||
|
// and it will only keep the latest message if multiple offers are received in a short time
|
||||||
|
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||||
|
// and also to avoid processing old offers if multiple offers are received in a short time
|
||||||
|
// the listener will always process the latest offer
|
||||||
|
relayListener *AsyncOfferListener
|
||||||
|
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
l := NewOfferListener(offer)
|
h.relayListener = NewAsyncOfferListener(offer)
|
||||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
}
|
||||||
|
|
||||||
|
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
|
h.iceListener = offer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) Listen(ctx context.Context) {
|
func (h *Handshaker) Listen(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
// received confirmation from the remote peer -> ready to proceed
|
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
if h.relayListener != nil {
|
||||||
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.iceListener != nil {
|
||||||
|
h.iceListener(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.sendAnswer(); err != nil {
|
if err := h.sendAnswer(); err != nil {
|
||||||
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, listener := range h.onNewOfferListeners {
|
|
||||||
listener.Notify(&remoteOfferAnswer)
|
|
||||||
}
|
|
||||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
for _, listener := range h.onNewOfferListeners {
|
if h.relayListener != nil {
|
||||||
listener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.iceListener != nil {
|
||||||
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
h.log.Infof("stop listening for remote offers and answers")
|
h.log.Infof("stop listening for remote offers and answers")
|
||||||
|
|||||||
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
|
|||||||
return oa.SessionID.String()
|
return oa.SessionID.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
type OfferListener struct {
|
type AsyncOfferListener struct {
|
||||||
fn callbackFunc
|
fn callbackFunc
|
||||||
running bool
|
running bool
|
||||||
latest *OfferAnswer
|
latest *OfferAnswer
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
|
||||||
return &OfferListener{
|
return &AsyncOfferListener{
|
||||||
fn: fn,
|
fn: fn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||||
o.mu.Lock()
|
o.mu.Lock()
|
||||||
defer o.mu.Unlock()
|
defer o.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
|
|||||||
runChan <- struct{}{}
|
runChan <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
hl := NewOfferListener(longRunningFn)
|
hl := NewAsyncOfferListener(longRunningFn)
|
||||||
|
|
||||||
hl.Notify(dummyOfferAnswer)
|
hl.Notify(dummyOfferAnswer)
|
||||||
hl.Notify(dummyOfferAnswer)
|
hl.Notify(dummyOfferAnswer)
|
||||||
|
|||||||
@@ -18,4 +18,5 @@ type WGIface interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
RemoveEndpointAddress(key string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
|||||||
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||||
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
if w.agentConnecting {
|
if w.agent != nil || w.agentConnecting {
|
||||||
w.log.Debugf("agent connection is in progress, skipping the offer")
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.agent != nil {
|
|
||||||
// backward compatibility with old clients that do not send session ID
|
// backward compatibility with old clients that do not send session ID
|
||||||
if remoteOfferAnswer.SessionID == nil {
|
if remoteOfferAnswer.SessionID == nil {
|
||||||
w.log.Debugf("agent already exists, skipping the offer")
|
w.log.Debugf("agent already exists, skipping the offer")
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
||||||
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
if err := w.agent.Close(); err != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionID, err := NewICESessionID()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to create new session ID: %s", err)
|
||||||
|
}
|
||||||
|
w.sessionID = sessionID
|
||||||
w.agent = nil
|
w.agent = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
w.log.Debugf("recreate ICE agent")
|
if remoteOfferAnswer.SessionID != nil {
|
||||||
|
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
|
||||||
|
}
|
||||||
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
||||||
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.agent = agent
|
w.agent = agent
|
||||||
w.agentDialerCancel = dialerCancel
|
w.agentDialerCancel = dialerCancel
|
||||||
w.agentConnecting = true
|
w.agentConnecting = true
|
||||||
w.muxAgent.Unlock()
|
if remoteOfferAnswer.SessionID != nil {
|
||||||
|
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||||
|
} else {
|
||||||
|
w.remoteSessionID = ""
|
||||||
|
}
|
||||||
|
|
||||||
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
w.lastSuccess = time.Now()
|
w.lastSuccess = time.Now()
|
||||||
if remoteOfferAnswer.SessionID != nil {
|
|
||||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
|
||||||
}
|
|
||||||
w.muxAgent.Unlock()
|
w.muxAgent.Unlock()
|
||||||
|
|
||||||
// todo: the potential problem is a race between the onConnectionStateChange
|
// todo: the potential problem is a race between the onConnectionStateChange
|
||||||
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
// todo review does it make sense to generate new session ID all the time when w.agent==agent
|
|
||||||
sessionID, err := NewICESessionID()
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed to create new session ID: %s", err)
|
|
||||||
}
|
|
||||||
w.sessionID = sessionID
|
|
||||||
|
|
||||||
if w.agent == agent {
|
if w.agent == agent {
|
||||||
|
// consider to remove from here and move to the OnNewOffer
|
||||||
|
sessionID, err := NewICESessionID()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to create new session ID: %s", err)
|
||||||
|
}
|
||||||
|
w.sessionID = sessionID
|
||||||
w.agent = nil
|
w.agent = nil
|
||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
|
w.remoteSessionID = ""
|
||||||
}
|
}
|
||||||
w.muxAgent.Unlock()
|
w.muxAgent.Unlock()
|
||||||
}
|
}
|
||||||
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
|
w.closeAgent(agent, dialerCancel)
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected()
|
||||||
}
|
}
|
||||||
w.closeAgent(agent, dialerCancel)
|
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProbeResult holds the info about the result of a relay probe request
|
// ProbeResult holds the info about the result of a relay probe request
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dnsTimeout = 8 * time.Second
|
const dnsTimeout = 8 * time.Second
|
||||||
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
|
}
|
||||||
|
|
||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
return fmt.Errorf("setup routing: %w", err)
|
return fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Routing cleanup complete")
|
log.Info("Routing cleanup complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
nbnet.SetVPNInterfaceName("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -22,7 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const localSubnetsCacheTTL = 15 * time.Minute
|
const localSubnetsCacheTTL = 15 * time.Minute
|
||||||
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove hooks selectively
|
hooks.RemoveWriteHooks()
|
||||||
nbnet.RemoveDialerHooks()
|
hooks.RemoveCloseHooks()
|
||||||
nbnet.RemoveListenerHooks()
|
hooks.RemoveAddressRemoveHooks()
|
||||||
|
|
||||||
if err := r.refCounter.Flush(); err != nil {
|
if err := r.refCounter.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush route manager: %w", err)
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
}
|
}
|
||||||
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
afterHook := func(connID hooks.ConnectionID) error {
|
||||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
for _, ip := range initAddresses {
|
||||||
if err := beforeHook("init", ip); err != nil {
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := beforeHook("init", prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
hooks.AddWriteHook(beforeHook)
|
||||||
if ctx.Err() != nil {
|
hooks.AddCloseHook(afterHook)
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
for _, ip := range resolvedIPs {
|
|
||||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
||||||
return beforeHook(connID, ip.IP)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
|
||||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
type dialer interface {
|
||||||
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||||
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
index, err := net.InterfaceByName(wgInterface.Name())
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user