Compare commits

..

37 Commits

Author SHA1 Message Date
Pascal Fischer
ee8739760d remove status attr and unused function 2025-10-14 17:46:01 +02:00
Pascal Fischer
63b003f255 add extensive store metrics on all methods 2025-10-14 17:42:28 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
Zoltan Papp
9021bb512b [client] Recreate agent when receive new session id (#4564)
When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time.
In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started.
2025-10-08 17:14:24 +02:00
hakansa
768332820e [client] Implement DNS query caching in DNSForwarder (#4574)
implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails.

- Added a new cache component to store DNS query results by domain and query type
- Integrated cache storage after successful DNS resolutions
- Enhanced error handling to serve cached responses as fallback when upstream DNS fails
2025-10-08 16:54:27 +02:00
hakansa
229c65ffa1 Enhance showLoginURL to include connection status check and auto-close functionality (#4525) 2025-10-08 12:42:15 +02:00
Zoltan Papp
4d33567888 [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228)
* When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder.
2025-10-08 03:12:16 +02:00
Viktor Liu
88467883fc [management,signal] Remove ws-proxy read deadline (#4598) 2025-10-06 22:05:48 +02:00
Viktor Liu
954f40991f [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) 2025-10-06 21:22:19 +02:00
Maycon Santos
34341d95a9 Adjust signal port for websocket connections (#4594) 2025-10-06 15:22:02 -03:00
Viktor Liu
e7b5537dcc Add websocket paths including relay to nginx template (#4573) 2025-10-02 13:51:39 +02:00
hakansa
95794f53ce [client] fix Windows NRPT Policy Path (#4572)
[client] fix Windows NRPT Policy Path
2025-10-02 17:42:25 +07:00
hakansa
9bcd3ebed4 [management,client] Make DNS ForwarderPort Configurable & Change Well Known Port (#4479)
makes the DNS forwarder port configurable in the management and client components, while changing the well-known port from 5454 to 22054. The change includes version-aware port assignment to ensure backward compatibility.

- Adds a configurable `ForwarderPort` field to the DNS configuration protocol
- Implements version-based port computation that returns the new port (22054) only when all peers support version 0.59.0 or newer
- Updates the client to dynamically restart the DNS forwarder when the port changes
2025-10-02 01:02:10 +02:00
Maycon Santos
b85045e723 [misc] Update infra scripts with ws proxy for browser client (#4566)
* Update infra scripts with ws proxy for browser client

* add ws proxy to nginx tmpl
2025-10-02 00:52:54 +02:00
Viktor Liu
4d7e59f199 [client,signal,management] Adjust browser client ws proxy paths (#4565) 2025-10-02 00:10:47 +02:00
Viktor Liu
b5daec3b51 [client,signal,management] Add browser client support (#4415) 2025-10-01 20:10:11 +02:00
Zoltan Papp
5e1a40c33f [client] Order the list of candidates for proper comparison (#4561)
Order the list of candidates for proper comparison
2025-09-30 23:40:46 +02:00
Zoltan Papp
e8d301fdc9 [client] Fix/pkg loss (#3338)
The Relayed connection setup is optimistic. It does not have any confirmation of an established end-to-end connection. Peers start sending WireGuard handshake packets immediately after the successful offer-answer handshake.
Meanwhile, for successful P2P connection negotiation, we change the WireGuard endpoint address, but this change does not trigger new handshake initiation. Because the peer switched from Relayed connection to P2P, the packets from the Relay server are dropped and must wait for the next WireGuard handshake via P2P.

To avoid this scenario, the relayed WireGuard proxy no longer drops the packets. Instead, it rewrites the source address to the new P2P endpoint and continues forwarding the packets.

We still have one corner case: if the Relayed server negotiation chooses a server that has not been used before. In this case, one side of the peer connection will be slower to reach the Relay server, and the Relay server will drop the handshake packet.

If everything goes well we should see exactly 5 seconds improvements between the WireGuard configuration time and the handshake time.
2025-09-30 15:31:18 +02:00
hakansa
17bab881f7 [client] Add Windows DNS Policies To GPO Path Always (#4460)
[client] Add Windows DNS Policies To GPO Path Always (#4460)
2025-09-26 16:42:18 +07:00
Vlad
25ed58328a [management] fix network map dns filter (#4547) 2025-09-25 16:29:14 +02:00
hakansa
644ed4b934 [client] Add WireGuard interface lifecycle monitoring (#4370)
* [client] Add WireGuard interface lifecycle monitoring
2025-09-25 15:36:26 +07:00
Pascal Fischer
58faa341d2 [management] Add logs for update channel (#4527) 2025-09-23 12:06:10 +02:00
Viktor Liu
5853b5553c [client] Skip interface for route lookup if it doesn't exist (#4524) 2025-09-22 14:32:00 +02:00
Zoltan Papp
998fb30e1e [client] Check the client status in the earlier phase (#4509)
This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on:

  • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling
  • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability
  • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops

  Key Changes

  Core Status Improvements:
  - Added waitForReady optional field to StatusRequest proto (daemon.proto:190)
  - Enhanced status checking logic to detect client state changes earlier in the connection process
  - Improved handling of client permanent exit scenarios from retry loops

  Connection & Context Management:
  - Fixed context cancellation in management and signal client retry mechanisms
  - Added proper context propagation for Login operations
  - Enhanced gRPC connection handling with better timeout management

  Error Handling & Cleanup:
  - Moved feedback channels to upper layers for better separation of concerns
  - Improved error handling patterns throughout the client server implementation
  - Fixed synchronization issues and removed debug logging
2025-09-20 22:14:01 +02:00
Maycon Santos
e254b4cde5 [misc] Update SIGN_PIPE_VER to version 0.0.23 (#4521) 2025-09-20 10:24:04 +02:00
Zoltan Papp
ead1c618ba [client] Do not run up cmd if not needed in docker (#4508)
optimizes the NetBird client startup process by avoiding unnecessary login commands when the peer is already authenticated. The changes increase the default login timeout and expand the log message patterns used to detect successful authentication.

- Increased default login timeout from 1 to 5 seconds for more reliable authentication detection
- Enhanced log pattern matching to detect both registration and ready states
- Added extended regex support for more flexible pattern matching
2025-09-20 10:00:18 +02:00
Viktor Liu
55126f990c [client] Use native windows sock opts to avoid routing loops (#4314)
- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed
 - Add methods to return the next best interface after the NetBird interface.
- Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method.
- Some refactoring to avoid import cycles
- Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
2025-09-20 09:31:04 +02:00
Misha Bragin
90577682e4 Add a new product demo video (#4520) 2025-09-19 13:06:44 +02:00
Bethuel Mmbaga
dc30dcacce [management] Filter DNS records to include only peers to connect (#4517)
DNS record filtering to only include peers that a peer can connect to, reducing unnecessary DNS data in the peer's network map.

- Adds a new `filterZoneRecordsForPeers` function to filter DNS records based on peer connectivity
- Modifies `GetPeerNetworkMap` to use filtered DNS records instead of all records in the custom zone
- Includes comprehensive test coverage for the new filtering functionality
2025-09-18 18:57:07 +02:00
Diego Romar
2c87fa6236 [android] Add OnLoginSuccess callback to URLOpener interface (#4492)
The callback will be fired once login -> internal.Login
completes without errors
2025-09-18 15:07:42 +02:00
hakansa
ec8d83ade4 [client] [UI] Down & Up NetBird Async When Settings Updated
[client] [UI] Down & Up NetBird Async When Settings Updated
2025-09-18 18:13:29 +07:00
235 changed files with 7551 additions and 1902 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum skip: go.mod,go.sum
golangci: golangci:
strategy: strategy:

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.22" SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -0,0 +1,67 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored Normal file
View File

View File

@@ -2,6 +2,18 @@ version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird - id: netbird
dir: client dir: client
binary: netbird binary: netbird
@@ -115,6 +127,11 @@ archives:
- builds: - builds:
- netbird - netbird
- netbird-static - netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -1,3 +1,4 @@
<div align="center"> <div align="center">
<br/> <br/>
<br/> <br/>
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform ### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" /> https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video) ### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1" NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/client/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile

View File

@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user // the backend want to show an url for the user
type URLOpener interface { type URLOpener interface {
Open(string) Open(string)
OnLoginSuccess()
} }
// Auth can register or login new client // Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error { err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken) err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil return nil
} }

8
client/cmd/debug_js.go Normal file
View File

@@ -0,0 +1,8 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server. // DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3) ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() defer cancel()
return grpc.DialContext( return grpc.DialContext(

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{}) status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil { if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err) return fmt.Errorf("unable to get daemon status: %v", err)
} }

View File

@@ -23,23 +23,29 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started") var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
deviceName string deviceName string
config *profilemanager.Config config *profilemanager.Config
mu sync.Mutex mu sync.Mutex
cancel context.CancelFunc cancel context.CancelFunc
setupKey string setupKey string
jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
} }
// Options configures a new Client // Options configures a new Client.
type Options struct { type Options struct {
// DeviceName is this peer's name in the network // DeviceName is this peer's name in the network
DeviceName string DeviceName string
// SetupKey is used for authentication // SetupKey is used for authentication
SetupKey string SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL // ManagementURL overrides the default management server URL
ManagementURL string ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface // PreSharedKey is the pre-shared key for the WireGuard interface
@@ -58,8 +64,35 @@ type Options struct {
DisableClientRoutes bool DisableClientRoutes bool
} }
// New creates a new netbird embedded client // validateCredentials checks that exactly one credential type is provided
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) { func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil { if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput) logrus.SetOutput(opts.LogOutput)
} }
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err) return nil, fmt.Errorf("create config: %w", err)
} }
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{ return &Client{
deviceName: opts.DeviceName, deviceName: opts.DeviceName,
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config, config: config,
}, nil }, nil
} }
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
@@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan struct{}, 1) run := make(chan struct{})
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
} }
} }
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network. // Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address) return nsnet.DialContext(ctx, network, address)
} }
// ListenTCP listens on the given address in the netbird network // ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) { func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr) return nsnet.ListenTCP(tcpAddr)
} }
// ListenUDP listens on the given address in the netbird network // ListenUDP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) { func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (

View File

@@ -4,15 +4,9 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt"
"net"
"os/user"
"runtime" "runtime"
"time" "time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -21,35 +15,9 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func WithCustomDialer() grpc.DialOption { // Backoff returns a backoff configuration for gRPC calls
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second b.MaxElapsedTime = 10 * time.Second
@@ -57,9 +25,12 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { // for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -71,14 +42,14 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
})) }))
} }
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
connCtx, connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@@ -0,0 +1,44 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

13
client/grpc/dialer_js.go Normal file
View File

@@ -0,0 +1,13 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -3,7 +3,7 @@ package bind
import ( import (
wireguard "golang.zx2c4.com/wireguard/conn" wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,5 +1,17 @@
package bind package bind
import wgConn "golang.zx2c4.com/wireguard/conn" import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -0,0 +1,7 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,6 +1,9 @@
//go:build !js
package bind package bind
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@@ -17,14 +20,9 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct { type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
} }
@@ -42,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported. // use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct { type ICEBind struct {
*wgConn.StdNetBind *wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn udpmux.FilterFn filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn address wgaddr.Address
endpointsMu sync.Mutex mtu uint16
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
recvChan chan recvMessage
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one // new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{} closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool closed bool
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
} }
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
mtu: mtu,
address: address,
activityRecorder: NewActivityRecorder(), activityRecorder: NewActivityRecorder(),
} }
@@ -83,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
return ib return ib
} }
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false s.closed = false
s.closedChanMu.Lock() s.closedChanMu.Lock()
@@ -139,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP) delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock() b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()] conn, ok := b.endpoints[ep.DstIP()]
@@ -271,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select { select {
case <-c.closedChan: case <-c.closedChan:
return 0, net.ErrClosed return 0, net.ErrClosed
case msg, ok := <-c.RecvChan: case msg, ok := <-c.recvChan:
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
} }

View File

@@ -0,0 +1,6 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -0,0 +1,125 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd //go:build linux || windows || freebsd || js || wasip1
package configurer package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows //go:build !windows && !js
package configurer package configurer

View File

@@ -0,0 +1,23 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
@@ -409,7 +470,7 @@ func toBytes(s string) (int64, error) {
} }
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark
} }
return 0 return 0

View File

@@ -15,8 +15,8 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
@@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := udpmux.UniversalUDPMuxParams{ bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: udpConn, UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,

View File

@@ -1,9 +1,11 @@
package device package device
import ( import (
"errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
@@ -12,9 +14,15 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
}
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address wgaddr.Address address wgaddr.Address
@@ -22,7 +30,7 @@ type TunNetstackDevice struct {
key string key string
mtu uint16 mtu uint16
listenAddress string listenAddress string
iceBind *bind.ICEBind bind Bind
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
@@ -33,7 +41,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
iceBind: iceBind, bind: bind,
} }
} }
@@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.filteredDevice, t.filteredDevice,
t.iceBind, t.bind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()
@@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
udpMux, err := t.iceBind.GetICEMux() udpMux, err := t.bind.GetICEMux()
if err != nil { if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
return nil, err return nil, err
} }
t.udpMux = udpMux
if udpMux != nil {
t.udpMux = udpMux
}
log.Debugf("netstack device is ready to use") log.Debugf("netstack device is ready to use")
return udpMux, nil return udpMux, nil
} }

View File

@@ -0,0 +1,27 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -21,4 +21,5 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -0,0 +1,6 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -0,0 +1,41 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -0,0 +1,27 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build linux && !android
package iface package iface
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -1,3 +1,5 @@
//go:build !js
package netstack package netstack
import ( import (

View File

@@ -0,0 +1,12 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -3,7 +3,7 @@
package udpmux package udpmux
import ( import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {

View File

@@ -16,28 +16,38 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type Bind interface {
Bind *bind.ICEBind SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
fakeNetIP *netip.AddrPort ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
} }
func NewProxyBind(bind *bind.ICEBind) *ProxyBind { type ProxyBind struct {
bind Bind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
mtu uint16
}
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
p := &ProxyBind{ p := &ProxyBind{
Bind: bind, bind: bind,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
} }
return p return p
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return nil return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{ return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) { func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once // Start the proxy only once
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyBind) Pause() { func (p *ProxyBind) Pause() {
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
} }
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
} }
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) buf := make([]byte, p.mtu)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
msg := bind.RecvMessage{ p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
Endpoint: p.wgBindEndpoint, p.pausedCond.L.Unlock()
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
} }
} }

View File

@@ -6,9 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"syscall"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -18,15 +16,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const ( const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
p.rawConn, err = p.prepareSenderRawSocket() p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil { if err != nil {
return err return err
} }
@@ -214,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)
ipH := &layers.IPv4{ ipH := &layers.IPv4{
DstIP: localhost, DstIP: localHostNetIP,
SrcIP: localhost, SrcIP: endpointAddr.IP,
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(port), SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort), DstPort: layers.UDPPort(p.localWGListenPort),
} }
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) return fmt.Errorf("serialize layers: %w", err)
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) return fmt.Errorf("write to raw conn: %w", err)
} }
return nil return nil

View File

@@ -18,41 +18,42 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex paused bool
paused bool pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{ return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy, wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err return err
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgRelayedEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyWrapper) Pause() { func (p *ProxyWrapper) Pause() {
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
} }
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error { func (p *ProxyWrapper) CloseConn() error {
if e.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
e.cancel() p.cancel()
e.closeListener.SetCloseListener(nil) p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.pausedCond.L.Lock()
return fmt.Errorf("close remote conn: %w", err) p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
p.closeListener.Notify() p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
} }
return 0, err return 0, err
} }

View File

@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -1,31 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -3,24 +3,25 @@ package wgproxy
import ( import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
) )
type USPFactory struct { type USPFactory struct {
bind *bind.ICEBind bind proxyBind.Bind
mtu uint16
} }
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy") log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{ f := &USPFactory{
bind: iceBind, bind: bind,
mtu: mtu,
} }
return f return f
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind) return proxyBind.NewProxyBind(w.bind, w.mtu)
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -11,6 +11,11 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
//and rewrite the src address to the endpoint address.
//With this logic can avoid the package loss from relayed connections.
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func()) SetDisconnectListener(disconnected func())
} }

View File

@@ -3,54 +3,82 @@
package wgproxy package wgproxy
import ( import (
"context" "fmt"
"os" "net"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
func TestProxyCloseByRemoteConnEBPF(t *testing.T) { func seedProxies() ([]proxyInstance, error) {
if os.Getenv("GITHUB_ACTIONS") != "true" { pl := make([]proxyInstance, 0)
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
} }
defer func() { pEbpf := proxyInstance{
if err := ebpfProxy.Free(); err != nil { name: "ebpf kernel proxy",
t.Errorf("failed to free ebpf proxy: %s", err) proxy: ebpf.NewProxyWrapper(ebpfProxy),
} wgPort: 51831,
}() closeFn: ebpfProxy.Free,
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
} }
pl = append(pl, pEbpf)
for _, tt := range tests { pUDP := proxyInstance{
t.Run(tt.name, func(t *testing.T) { name: "udp kernel proxy",
relayedConn := newMockConn() proxy: udp.NewWGUDPProxy(51832, 1280),
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) wgPort: 51832,
if err != nil { closeFn: func() error { return nil },
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
} }
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
} }

View File

@@ -0,0 +1,39 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux
package wgproxy package wgproxy
import ( import (
@@ -7,12 +5,9 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
os.Exit(code) os.Exit(code)
} }
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct { type mocConn struct {
closeChan chan struct{} closeChan chan struct{}
closed bool closed bool
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) { func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tests := []struct { tests, err := seedProxyForProxyCloseByRemoteConn()
name string if err != nil {
proxy Proxy t.Fatalf("error: %v", err)
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) defer func() {
if err := ebpfProxy.Listen(); err != nil { _ = relayedConn.Close()
t.Fatalf("failed to initialize ebpf proxy: %s", err) }()
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn() relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}) })
} }
} }
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp package udp
import ( import (
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16 mtu uint16
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
ctx context.Context srcFakerConn *SrcFaker
cancel context.CancelFunc sendPkg func(data []byte) (int, error)
closeMu sync.Mutex ctx context.Context
closed bool cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex paused bool
paused bool pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu, mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn p.remoteConn = remoteConn
return err return err
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToRemote(p.ctx) go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
} }
func (p *WGUDPProxy) close() error { func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
p.cancel() p.cancel()
var result *multierror.Error p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
_, err = p.sendPkg(buf[:n])
_, err = p.localConn.Write(buf[:n]) p.pausedCond.L.Unlock()
p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )

View File

@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains { for i, domain := range domains {
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
if r.gpo { gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
}
singleDomain := []string{domain} singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
}
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
}
} }
log.Debugf("added NRPT entry for domain: %s", domain) log.Debugf("added NRPT entry for domain: %s", domain)
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
} }

View File

@@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (hostManager, error) {
return &noopHostConfigurator{}, nil
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type ServiceViaMemory struct { type ServiceViaMemory struct {

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err) log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string

View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
)
type ShutdownState struct{}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
return nil
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
type upstreamResolver struct { type upstreamResolver struct {

View File

@@ -0,0 +1,78 @@
package dnsfwd
import (
"net/netip"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
)
type cache struct {
mu sync.RWMutex
records map[string]*cacheEntry
}
type cacheEntry struct {
ip4Addrs []netip.Addr
ip6Addrs []netip.Addr
}
func newCache() *cache {
return &cache{
records: make(map[string]*cacheEntry),
}
}
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.records[normalizeDomain(domain)]
if !exists {
return nil, false
}
switch reqType {
case dns.TypeA:
return slices.Clone(entry.ip4Addrs), true
case dns.TypeAAAA:
return slices.Clone(entry.ip6Addrs), true
default:
return nil, false
}
}
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
norm := normalizeDomain(domain)
entry, exists := c.records[norm]
if !exists {
entry = &cacheEntry{}
c.records[norm] = entry
}
switch reqType {
case dns.TypeA:
entry.ip4Addrs = slices.Clone(addrs)
case dns.TypeAAAA:
entry.ip6Addrs = slices.Clone(addrs)
}
}
// unset removes cached entries for the given domain and request type.
func (c *cache) unset(domain string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.records, normalizeDomain(domain))
}
// normalizeDomain converts an input domain into a canonical form used as cache key:
// lowercase and fully-qualified (with trailing dot).
func normalizeDomain(domain string) string {
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
return dns.Fqdn(strings.ToLower(domain))
}

View File

@@ -0,0 +1,86 @@
package dnsfwd
import (
"net/netip"
"testing"
)
func mustAddr(t *testing.T, s string) netip.Addr {
t.Helper()
a, err := netip.ParseAddr(s)
if err != nil {
t.Fatalf("parse addr %s: %v", s, err)
}
return a
}
func TestCacheNormalization(t *testing.T) {
c := newCache()
// Mixed case, without trailing dot
domainInput := "ExAmPlE.CoM"
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
// Lookup with lower, with trailing dot
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
}
// Lookup with different casing again
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
}
}
func TestCacheSeparateTypes(t *testing.T) {
c := newCache()
domain := "test.local"
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
c.set(domain, 1 /* A */, ipv4)
c.set(domain, 28 /* AAAA */, ipv6)
got4, ok4 := c.get(domain, 1)
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
}
got6, ok6 := c.get(domain, 28)
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
}
}
func TestCacheCloneOnGetAndSet(t *testing.T) {
c := newCache()
domain := "clone.test"
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
c.set(domain, 1, src)
// Mutate source slice; cache should be unaffected
src[0] = mustAddr(t, "9.9.9.9")
got, ok := c.get(domain, 1)
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
}
// Mutate returned slice; internal cache should remain unchanged
got[0] = mustAddr(t, "4.4.4.4")
got2, ok2 := c.get(domain, 1)
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
}
}
func TestCacheMiss(t *testing.T) {
c := newCache()
if got, ok := c.get("missing.example", 1); ok || got != nil {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
}
}

View File

@@ -46,6 +46,7 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewaller firewall firewaller
resolver resolver resolver resolver
cache *cache
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
cache: newCache(),
} }
} }
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
// remove cache entries for domains that no longer appear
f.removeStaleCacheEntries(f.fwdEntries, entries)
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries)) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
// removeStaleCacheEntries unsets cache items for domains that were present
// in the old list but not present in the new list.
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
if f.cache == nil {
return
}
newSet := make(map[string]struct{}, len(newEntries))
for _, e := range newEntries {
if e == nil {
continue
}
newSet[e.Domain.PunycodeString()] = struct{}{}
}
for _, e := range oldEntries {
if e == nil {
continue
}
pattern := e.Domain.PunycodeString()
if _, ok := newSet[pattern]; !ok {
f.cache.unset(pattern)
}
}
}
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
var result *multierror.Error var result *multierror.Error
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp return resp
} }
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
} }
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(
ctx context.Context,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
switch { log.Warnf(errResolveFailed, domain, err)
case errors.As(err, &dnsErr): if writeErr := w.WriteMsg(resp); writeErr != nil {
resp.Rcode = dns.RcodeServerFailure log.Errorf("failed to write failure DNS response: %v", writeErr)
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
} }
return
}
if dnsErr.Server != "" { // NotFound: set NXDOMAIN / appropriate code via helper.
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) if dnsErr.IsNotFound {
} else { f.setResponseCodeForNotFound(ctx, resp, domain, qType)
log.Warnf(errResolveFailed, domain, err) if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
default: f.cache.set(domain, question.Qtype, nil)
resp.Rcode = dns.RcodeServerFailure return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }
if err := w.WriteMsg(resp); err != nil { // Write final failure response.
log.Errorf("failed to write failure DNS response: %v", err) if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
} }

View File

@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
} }
// Ensures that when the first query succeeds and populates the cache,
// a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("1.2.3.4")
// First call resolves successfully and populates cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{ip}, nil).Once()
// Second call fails upstream; forwarder should serve from cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
// First query: populate cache
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
// Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("9.8.7.6")
// Initial resolution with mixed case to populate cache
mixedQuery := "ExAmPlE.CoM"
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
Return([]netip.Addr{ip}, nil).Once()
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Subsequent query without dot and upper case should hit cache even if upstream fails
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios // Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -11,14 +12,18 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
) )
const ( const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also dnsTTL = 60 //seconds
ListenPort = 5353
dnsTTL = 60 //seconds
) )
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
@@ -37,6 +42,18 @@ type Manager struct {
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
} }
func ListenPort() uint16 {
listenPortMu.RLock()
defer listenPortMu.RUnlock()
return listenPort
}
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{ return &Manager{
firewall: fw, firewall: fw,
@@ -54,7 +71,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() { go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
@@ -94,7 +111,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort}, Values: []uint16{ListenPort()},
} }
if m.firewall == nil { if m.firewall == nil {

View File

@@ -198,6 +198,13 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -240,6 +247,7 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -341,6 +349,9 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
return nil return nil
} }
@@ -457,14 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
iceCfg := icemaker.Config{ iceCfg := e.createICEConfig()
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
@@ -477,6 +481,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
}()
return nil return nil
} }
@@ -1064,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
// Ingress forward rules // Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1322,14 +1342,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
Addr: e.getRosenpassAddr(), Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive, PermissiveMode: e.config.RosenpassPermissive,
}, },
ICEConfig: icemaker.Config{ ICEConfig: e.createICEConfig(),
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
} }
serviceDependencies := peer.ServiceDependencies{ serviceDependencies := peer.ServiceDependencies{
@@ -1830,11 +1843,16 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder( func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) { ) {
if e.config.DisableServerRoutes { if e.config.DisableServerRoutes {
return return
} }
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@@ -1846,16 +1864,20 @@ func (e *Engine) updateDNSForwarder(
} }
if len(fwdEntries) > 0 { if len(fwdEntries) > 0 {
if e.dnsForwardMgr == nil { switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
log.Infof("started domain router service with %d entries", len(fwdEntries)) log.Infof("started domain router service with %d entries", len(fwdEntries))
} else { case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default:
e.dnsForwardMgr.UpdateDomains(fwdEntries) e.dnsForwardMgr.UpdateDomains(fwdEntries)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
@@ -1865,6 +1887,20 @@ func (e *Engine) updateDNSForwarder(
} }
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} }
func (e *Engine) GetNet() (*netstack.Net, error) { func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -0,0 +1,19 @@
//go:build !js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for non-WASM environments
func (e *Engine) createICEConfig() icemaker.Config {
return icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
}

View File

@@ -0,0 +1,18 @@
//go:build js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for WASM environment.
func (e *Engine) createICEConfig() icemaker.Config {
cfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
return cfg
}

View File

@@ -27,6 +27,10 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -42,10 +46,8 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -103,6 +105,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time LastActivitiesFunc func() map[string]monotime.Time
} }
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
} }
@@ -1584,7 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -28,6 +28,7 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -14,7 +14,7 @@ import (
"github.com/ti-mo/netfilter" "github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const defaultChannelSize = 100 const defaultChannelSize = 100

View File

@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection // check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
return false return false
} }

View File

@@ -0,0 +1,12 @@
package networkmonitor
import (
"context"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
// No-op for WASM - network changes don't apply
return nil
}

View File

@@ -28,10 +28,6 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
const (
defaultWgKeepAlive = 25 * time.Second
)
type ServiceDependencies struct { type ServiceDependencies struct {
StatusRecorder *Status StatusRecorder *Status
Signaler *Signaler Signaler *Signaler
@@ -117,6 +113,8 @@ type Conn struct {
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
endpointUpdater *EndpointUpdater
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
Log: connLog, Log: connLog,
config: config, config: config,
statusRecorder: services.StatusRecorder, statusRecorder: services.StatusRecorder,
signaler: services.Signaler, signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover, iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager, relayManager: services.RelayManager,
srWatcher: services.SrWatcher, srWatcher: services.SrWatcher,
semaphore: services.Semaphore, semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(), statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
} }
return conn, nil return conn, nil
@@ -172,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
if err := conn.removeWgPeer(); err != nil { if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
@@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
wgProxy.Work() wgProxy.Work()
} }
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy) conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
conn.wgProxyRelay.RedirectAs(ep)
}
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
@@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() {
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
@@ -418,10 +425,14 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -477,7 +488,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
} }
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { presharedKey := conn.presharedKey(rci.rosenpassPubKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
} }
@@ -514,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -545,17 +560,6 @@ func (conn *Conn) onGuardEvent() {
} }
} }
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
presharedKey := conn.presharedKey(remoteRPKey)
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
presharedKey,
)
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -698,10 +702,6 @@ func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
} }
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.Log.Warnf("Failed to update wg peer configuration: %v", err) conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteOffer(OfferAnswer{ conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteAnswer(OfferAnswer{ conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -0,0 +1,105 @@
package peer
import (
"context"
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
defaultWgKeepAlive = 25 * time.Second
fallbackDelay = 5 * time.Second
)
type EndpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
}
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
return &EndpointUpdater{
log: log,
wgConfig: wgConfig,
initiator: initiator,
}
}
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.initiator {
e.log.Debugf("configure up WireGuard as initiatr")
return e.updateWireGuardPeer(addr, presharedKey)
}
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
e.log.Debugf("configure up WireGuard and wait for handshake")
return e.updateWireGuardPeer(nil, presharedKey)
}
func (e *EndpointUpdater) RemoveWgPeer() error {
e.mu.Lock()
defer e.mu.Unlock()
e.waitForCloseTheDelayedUpdate()
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil {
return
}
e.cancelFunc()
e.cancelFunc = nil
e.updateWg.Wait()
}
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
defer e.updateWg.Done()
t := time.NewTimer(fallbackDelay)
defer t.Stop()
select {
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
presharedKey,
)
}

View File

@@ -0,0 +1,20 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
func GetICEMonitorPeriod() time.Duration {
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return defaultCandidatesMonitorPeriod
}

View File

@@ -3,6 +3,8 @@ package guard
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"sort"
"sync" "sync"
"time" "time"
@@ -14,8 +16,8 @@ import (
) )
const ( const (
candidatesMonitorPeriod = 5 * time.Minute defaultCandidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second candidateGatheringTimeout = 5 * time.Second
) )
type ICEMonitor struct { type ICEMonitor struct {
@@ -23,16 +25,19 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidates []ice.Candidate currentCandidatesAddress []string
candidatesMu sync.Mutex candidatesMu sync.Mutex
} }
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
cm := &ICEMonitor{ cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1), ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
iceConfig: config, iceConfig: config,
tickerPeriod: period,
} }
return cm return cm
} }
@@ -44,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return return
} }
ticker := time.NewTicker(candidatesMonitorPeriod) // Initial check to populate the candidates for later comparison
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -115,16 +125,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock() cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock() defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) { newAddresses := make([]string, len(newCandidates))
cm.currentCandidates = newCandidates for i, c := range newCandidates {
newAddresses[i] = c.Address()
}
sort.Strings(newAddresses)
if len(cm.currentCandidatesAddress) != len(newAddresses) {
cm.currentCandidatesAddress = newAddresses
return true return true
} }
for i, candidate := range cm.currentCandidates { // Compare elements
if candidate.Address() != newCandidates[i].Address() { if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
cm.currentCandidates = newCandidates cm.currentCandidatesAddress = newAddresses
return true return true
}
} }
return false return false

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged) go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,13 +44,19 @@ type OfferAnswer struct {
} }
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []*OfferListener // relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
l := NewOfferListener(offer) h.relayListener = NewAsyncOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l) }
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
if err := h.sendAnswer(); err != nil { if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err) h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue continue
} }
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { if h.relayListener != nil {
listener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String() return oa.SessionID.String()
} }
type OfferListener struct { type AsyncOfferListener struct {
fn callbackFunc fn callbackFunc
running bool running bool
latest *OfferAnswer latest *OfferAnswer
mu sync.Mutex mu sync.Mutex
} }
func NewOfferListener(fn callbackFunc) *OfferListener { func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
return &OfferListener{ return &AsyncOfferListener{
fn: fn, fn: fn,
} }
} }
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock() o.mu.Lock()
defer o.mu.Unlock() defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{} runChan <- struct{}{}
} }
hl := NewOfferListener(longRunningFn) hl := NewAsyncOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)

View File

@@ -18,4 +18,5 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
RemoveEndpointAddress(key string) error
} }

View File

@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agentConnecting { if w.agent != nil || w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID // backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil { if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return return
} }
if w.remoteSessionID == *remoteOfferAnswer.SessionID { if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
} }
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
w.log.Debugf("recreate ICE agent") if remoteOfferAnswer.SessionID != nil {
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
dialerCtx, dialerCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return return
} }
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
w.muxAgent.Unlock() if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
go w.connect(dialerCtx, agent, remoteOfferAnswer) go w.connect(dialerCtx, agent, remoteOfferAnswer)
} }
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock() w.muxAgent.Lock()
w.agentConnecting = false w.agentConnecting = false
w.lastSuccess = time.Now() w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock() w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange // todo: the potential problem is a race between the onConnectionStateChange
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
} }
w.muxAgent.Lock() w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = ""
} }
w.muxAgent.Unlock() w.muxAgent.Unlock()
} }
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
} }
w.closeAgent(agent, dialerCancel)
default: default:
return return
} }

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
// ProbeResult holds the info about the result of a relay probe request // ProbeResult holds the info about the result of a relay probe request

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
const dnsTimeout = 8 * time.Second const dnsTimeout = 8 * time.Second
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()

View File

@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier) sysOps := systemops.NewSysOps(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
}
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
return nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
log.Warnf("Failed cleaning up routing: %v", err) log.Warnf("Failed cleaning up routing: %v", err)
} }
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
return fmt.Errorf("setup routing: %w", err) return fmt.Errorf("setup routing: %w", err)
} }
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager); err != nil { if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
log.Errorf("Error cleaning up routing: %v", err) log.Errorf("Error cleaning up routing: %v", err)
} else { } else {
log.Info("Routing cleanup complete") log.Info("Routing cleanup complete")
} }
if runtime.GOOS == "windows" {
nbnet.SetVPNInterfaceName("")
}
} }
m.mux.Lock() m.mux.Lock()

View File

@@ -12,11 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
return nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
return nil return nil
} }

View File

@@ -3,7 +3,6 @@
package systemops package systemops
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -22,7 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/client/net/hooks"
) )
const localSubnetsCacheTTL = 15 * time.Minute const localSubnetsCacheTTL = 15 * time.Minute
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil return nil
} }
// TODO: Remove hooks selectively hooks.RemoveWriteHooks()
nbnet.RemoveDialerHooks() hooks.RemoveCloseHooks()
nbnet.RemoveListenerHooks() hooks.RemoveAddressRemoveHooks()
if err := r.refCounter.Flush(); err != nil { if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err) return fmt.Errorf("flush route manager: %w", err)
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err) return fmt.Errorf("adding route reference: %v", err)
} }
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
afterHook := func(connID nbnet.ConnectionID) error { afterHook := func(connID hooks.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil { if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err) return fmt.Errorf("remove route reference: %w", err)
} }
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
var merr *multierror.Error var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { prefix, err := util.GetPrefixFromIP(ip)
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
continue
}
if err := beforeHook("init", prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
} }
} }
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { hooks.AddWriteHook(beforeHook)
if ctx.Err() != nil { hooks.AddCloseHook(afterHook)
return ctx.Err()
}
var merr *multierror.Error hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
for _, ip := range resolvedIPs {
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(merr)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil { if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err) return fmt.Errorf("remove route reference: %w", err)
} }

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
) )
type dialer interface { type dialer interface {
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil) advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
}) })
intf, err := net.InterfaceByName(wgInterface.Name()) intf, err := net.InterfaceByName(wgInterface.Name())
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil) advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
}) })
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil) advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())

Some files were not shown because too many files have changed in this diff Show More