Compare commits

...

91 Commits

Author SHA1 Message Date
Pascal Fischer
affa8bf348 log checks 2025-07-24 19:54:53 +02:00
Pascal Fischer
c435c2727f [management] Log BufferUpdateAccountPeers caller (#4217) 2025-07-24 18:33:58 +02:00
Ali Amer
643730f770 [client] Correct minor issues in --filter-by-connection-type flag implementation for status command (#4214)
Signed-off-by: aliamerj <aliamer19ali@gmail.com>
2025-07-24 17:51:27 +02:00
Pascal Fischer
04fae00a6c [management] Log UpdateAccountPeers caller (#4216) 2025-07-24 17:44:48 +02:00
Pedro Maia Costa
1a9ea32c21 [management] scheduler cancel all jobs (#4158) 2025-07-24 16:25:21 +01:00
Pedro Maia Costa
0ea5d020a3 [management] extra settings integrated validator (#4136) 2025-07-24 16:12:29 +01:00
Viktor Liu
459c9ef317 [client] Add env and status flags for netbird service command (#3975) 2025-07-24 13:34:55 +02:00
Viktor Liu
e5e275c87a [client] Fix legacy routing exclusion routes in kernel mode (#4167) 2025-07-24 13:34:36 +02:00
Zoltan Papp
d311f57559 [ci] Temporarily disable race detection in Relay (#4210) 2025-07-24 13:14:49 +02:00
Zoltan Papp
1a28d18cde [client] Fix race issues in lazy tests (#4181)
* Fix race issues in lazy tests

* Fix test failure due to incorrect peer listener identification
2025-07-23 21:03:29 +02:00
Philippe Vaucher
91e7423989 [misc] Docker compose improvements (#4037)
* Use container defaults

* Remove docker compose version when generating zitadel config
2025-07-22 19:44:49 +02:00
Zoltan Papp
86c16cf651 [server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications in case the closed race dials.
In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
2025-07-21 19:58:17 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
Zoltan Papp
3e6eede152 [client] Fix elapsed time calculation when machine is in sleep mode (#4140) 2025-07-12 11:10:45 +02:00
Vlad
a76c8eafb4 [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-11 12:37:14 +03:00
Pedro Maia Costa
2b9f331980 always suffix ephemeral peer name (#4138) 2025-07-11 10:29:10 +01:00
Viktor Liu
a7ea881900 [client] Add rotated logs flag for debug bundle generation (#4100) 2025-07-10 16:13:53 +02:00
Vlad
8632dd15f1 [management] added cleanupWindow for collecting several ephemeral peers to delete (#4130)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-10 15:21:01 +02:00
Zoltan Papp
e3b40ba694 Update cli description of lazy connection (#4133) 2025-07-10 15:00:58 +02:00
Zoltan Papp
e59d75d56a Nil check in iface configurer (#4132) 2025-07-10 14:24:20 +02:00
Zoltan Papp
408f423adc [client] Disable pidfd check on Android 11 and below (#4127)
Disable pidfd check on Android 11 and below

On Android 11 (SDK <= 30) and earlier, pidfd-related system calls
are blocked by seccomp policies, causing SIGSYS crashes.

This change overrides `checkPidfdOnce` to return an error on
affected versions, preventing the use of unsupported pidfd features.
2025-07-09 22:16:08 +02:00
Misha Bragin
f17dd3619c [misc] update image in README.md (#4122) 2025-07-09 15:49:09 +02:00
Bethuel Mmbaga
969f1ed59a [management] Remove deleted user peers from groups on user deletion (#4121)
Refactors peer deletion to centralize group cleanup logic, ensuring deleted peers are consistently removed from all groups in one place.

- Removed redundant group removal code from DefaultAccountManager.DeletePeer
- Added group removal logic inside deletePeers to handle both single and multiple peer deletions
2025-07-09 10:14:10 +03:00
M. Essam
768ba24fda [management,rest] Add name/ip filters to peer management rest client (#4112) 2025-07-08 18:08:13 +02:00
Zoltan Papp
8942c40fde [client] Fix nil pointer exception in lazy connection (#4109)
Remove unused variable
2025-07-06 15:13:14 +02:00
Zoltan Papp
fbb1b55beb [client] refactor lazy detection (#4050)
This PR introduces a new inactivity package responsible for monitoring peer activity and notifying when peers become inactive.
Introduces a new Signal message type to close the peer connection after the idle timeout is reached.
Periodically checks the last activity of registered peers via a Bind interface.
Notifies via a channel when peers exceed a configurable inactivity threshold.
Default settings
DefaultInactivityThreshold is set to 15 minutes, with a minimum allowed threshold of 1 minute.

Limitations
This inactivity check does not support kernel WireGuard integration. In kernel–user space communication, the user space side will always be responsible for closing the connection.
2025-07-04 19:52:27 +02:00
Viktor Liu
77ec32dd6f [client] Implement dns routes for Android (#3989) 2025-07-04 16:43:11 +02:00
Bethuel Mmbaga
8c09a55057 [management] Log user id on account mismatch (#4101) 2025-07-04 10:51:58 +03:00
Pedro Maia Costa
f603ddf35e management: fix store get account peers without lock (#4092) 2025-07-04 08:44:08 +01:00
Krzysztof Nazarewski (kdn)
996b8c600c [management] replace invalid user with a clear error message about mismatched logins (#4097) 2025-07-03 16:36:36 +02:00
Maycon Santos
c4ed11d447 [client] Avoid logging setup keys on error message (#3962) 2025-07-03 16:22:18 +02:00
Viktor Liu
9afbecb7ac [client] Use unique sequence numbers for bsd routes (#4081)
updates the route manager on Unix to use a unique, incrementing sequence number for each route message instead of a fixed value.

Replace the static Seq: 1 with a call to r.getSeq()
Add an atomic seq field and the getSeq method in SysOps
2025-07-03 09:02:53 +02:00
Maycon Santos
2c81cf2c1e [management] Add account onboarding (#4084)
This PR introduces a new onboarding feature to handle such flows in the dashboard by defining an AccountOnboarding model, persisting it in the store, exposing CRUD operations in the manager and HTTP handlers, and updating API schemas and tests accordingly.

Add AccountOnboarding struct and embed it in Account
Extend Store and DefaultAccountManager with onboarding methods and SQL migrations
Update HTTP handlers, API types, OpenAPI spec, and add end-to-end tests
2025-07-03 09:01:32 +02:00
Pascal Fischer
551cb4e467 [management] expect specific error types on registration with setup key (#4094) 2025-07-02 20:04:28 +02:00
Maycon Santos
57961afe95 [doc] Add forum link (#4093)
* Add forum link

* Add forum link
2025-07-02 18:40:07 +02:00
Pascal Fischer
22678bce7f [management] add uniqueness constraint for peer ip and label and optimize generation (#4042) 2025-07-02 18:13:10 +02:00
Maycon Santos
6c633497bc [management] fix network update test for delete policy (#4086)
when adding a peer we calculate the network map an account using backpressure functions and some updates might arrive around the time we are deleting a policy.

This change ensures we wait enough time for the updates from add peer to be sent and read before continuing with the test logic
2025-07-02 12:25:31 +02:00
Carlos Hernandez
6922826919 [client] Support fullstatus without probes (#4052) 2025-07-02 10:42:47 +02:00
Maycon Santos
56a1a75e3f [client] Support random wireguard port on client (#4085)
Adds support for using a random available WireGuard port when the user specifies port `0`.

- Updates `freePort` logic to bind to the requested port (including `0`) without falling back to the default.
- Removes default port assignment in the configuration path, allowing `0` to propagate.
- Adjusts tests to handle dynamically assigned ports when using `0`.
2025-07-02 09:01:02 +02:00
Ali Amer
d9402168ad [management] Add option to disable default all-to-all policy (#3970)
This PR introduces a new configuration option `DisableDefaultPolicy` that prevents the creation of the default all-to-all policy when new accounts are created. This is useful for automation scenarios where explicit policies are preferred.
### Key Changes:
- Added DisableDefaultPolicy flag to the management server config
- Modified account creation logic to respect this flag
- Updated all test cases to explicitly pass the flag (defaulting to false to maintain backward compatibility)
- Propagated the flag through the account manager initialization chain

### Testing:

- Verified default behavior remains unchanged when flag is false
- Confirmed no default policy is created when flag is true
- All existing tests pass with the new parameter
2025-07-02 02:41:59 +02:00
Krzysztof Nazarewski (kdn)
dbdef04b9e [misc] getting-started-with-zitadel.sh: drop unnecessary port 8080 (#4075) 2025-07-02 02:35:13 +02:00
Maycon Santos
29cbfe8467 [misc] update sign pipeline version to v0.0.20 (#4082) 2025-07-01 16:23:31 +02:00
Maycon Santos
6ce8643368 [client] Run login popup on goroutine (#4080) 2025-07-01 13:45:55 +02:00
Krzysztof Nazarewski (kdn)
07d1ad35fc [misc] start the service after installation on arch linux (#4071) 2025-06-30 12:02:03 +02:00
Krzysztof Nazarewski (kdn)
ef6cd36f1a [misc] fix arch install.sh error with empty temporary dependencies
handle empty var before calling removal command
2025-06-30 11:59:35 +02:00
Krzysztof Nazarewski (kdn)
c1c71b6d39 [client] improve adding route log message (#4034)
from:
  Adding route to 1.2.3.4/32 via invalid IP @ 10 (wt0)
to:
  Adding route to 1.2.3.4/32 via no-ip @ 10 (wt0)
2025-06-30 11:57:42 +02:00
Pascal Fischer
0480507a10 [management] report networkmap duration in ms (#4064) 2025-06-28 11:38:15 +02:00
Krzysztof Nazarewski (kdn)
34ac4e4b5a [misc] fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG (#4055)
* fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG

fixes https://github.com/netbirdio/netbird/issues/4054

* un-quote the number

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-06-26 10:45:00 +02:00
Pascal Fischer
52ff9d9602 [management] remove unused transaction (#4053) 2025-06-26 01:34:22 +02:00
Pascal Fischer
1b73fae46e [management] add breakdown of network map calculation metrics (#4020) 2025-06-25 11:46:35 +02:00
Viktor Liu
d897365abc [client] Don't open cmd.exe during MSI actions (#4041) 2025-06-24 21:32:37 +02:00
Viktor Liu
f37aa2cc9d [misc] Specify netbird binary location in Dockerfiles (#4024) 2025-06-23 10:09:02 +02:00
Maycon Santos
5343bee7b2 [management] check and log on new management version (#4029)
This PR enhances the version checker to send a custom User-Agent header when polling for updates, and configures both the management CLI and client UI to use distinct agents. 

- NewUpdate now takes an `httpAgent` string to set the User-Agent header.
- `fetchVersion` builds a custom HTTP request (instead of `http.Get`) and sets the User-Agent.
- Management CLI and client UI now pass `"nb/management"` and `"nb/client-ui"` respectively to NewUpdate.
- Tests updated to supply an `httpAgent` constant.
- Logs if there is a new version available for management
2025-06-22 16:44:33 +02:00
Maycon Santos
870e29db63 [misc] add additional metrics (#4028)
* add additional metrics

we are collecting active rosenpass, ssh from the client side
we are also collecting active user peers and active users

* remove duplicated
2025-06-22 13:44:25 +02:00
Maycon Santos
08e9b05d51 [client] close windows when process needs to exit (#4027)
This PR fixes a bug by ensuring that the advanced settings and re-authentication windows are closed appropriately when the main GUI process exits.

- Updated runSelfCommand calls throughout the UI to pass a context parameter.
- Modified runSelfCommand’s signature and its internal command invocation to use exec.CommandContext for proper cancellation handling.
2025-06-22 10:33:04 +02:00
hakansa
3581648071 [client] Refactor showLoginURL to improve error handling and connection status checks (#4026)
This PR refactors showLoginURL to improve error handling and connection status checks by delaying the login fetch until user interaction and closing the pop-up if already connected.

- Moved s.login(false) call into the click handler to defer network I/O.
- Added a conn.Status check after opening the URL to skip reconnection if already connected.
- Enhanced error logs for missing verification URLs and service status failures.
2025-06-22 10:03:58 +02:00
Viktor Liu
2a51609436 [client] Handle lazy routing peers that are part of HA groups (#3943)
* Activate new lazy routing peers if the HA group is active
* Prevent lazy peers going to idle if HA group members are active (#3948)
2025-06-20 18:07:19 +02:00
Pascal Fischer
83457f8b99 [management] add transaction for integrated validator groups update and primary account update (#4014) 2025-06-20 12:13:24 +02:00
Pascal Fischer
b45284f086 [management] export ephemeral peer flag on api (#4004) 2025-06-19 16:46:56 +02:00
Bethuel Mmbaga
e9016aecea [management] Add backward compatibility for older clients without firewall rules port range support (#4003)
Adds backward compatibility for clients with versions prior to v0.48.0 that do not support port range firewall rules.

- Skips generation of firewall rules with multi-port ranges for older clients
- Preserves support for single-port ranges by treating them as individual port rules, ensuring compatibility with older clients
2025-06-19 13:07:06 +03:00
Viktor Liu
23b5d45b68 [client] Fix port range squashing (#4007) 2025-06-18 18:56:48 +02:00
Viktor Liu
0e5dc9d412 [client] Add more Android advanced settings (#4001) 2025-06-18 17:23:23 +02:00
Zoltan Papp
91f7ee6a3c Fix route notification
On Android ignore the dynamic roots in the route notifications
2025-06-18 16:49:03 +02:00
Bethuel Mmbaga
7c6b85b4cb [management] Refactor routes to use store methods (#2928) 2025-06-18 16:40:29 +03:00
hakansa
08c9107c61 [client] fix connection state handling (#3995)
[client] fix connection state handling (#3995)
2025-06-17 17:14:08 +03:00
hakansa
81d83245e1 [client] Fix logic in updateStatus to correctly handle connection state (#3994)
[client] Fix logic in updateStatus to correctly handle connection state (#3994)
2025-06-17 17:02:04 +03:00
Maycon Santos
af2b427751 [management] Avoid recalculating next peer expiration (#3991)
* Avoid recalculating next peer expiration

- Check if an account schedule is already running
- Cancel executing schedules only when changes occurs
- Add more context info to logs

* fix tests
2025-06-17 15:14:11 +02:00
hakansa
f61ebdb3bc [client] Fix DNS Interceptor Build Error (#3993)
[client] Fix DNS Interceptor Build Error
2025-06-17 16:07:14 +03:00
Viktor Liu
de7384e8ea [client] Tighten allowed domains for dns forwarder (#3978) 2025-06-17 14:03:00 +02:00
Viktor Liu
75c1be69cf [client] Prioritze the local resolver in the dns handler chain (#3965) 2025-06-17 14:02:30 +02:00
hakansa
424ae28de9 [client] Fix UI Download URL (#3990)
[client] Fix UI Download URL
2025-06-17 11:55:48 +03:00
Viktor Liu
d4a800edd5 [client] Fix status recorder panic (#3988) 2025-06-17 01:20:26 +02:00
Maycon Santos
dd9917f1a8 [misc] add missing images (#3987) 2025-06-16 21:05:49 +02:00
Viktor Liu
8df8c1012f [client] Support wildcard DNS on iOS (#3979) 2025-06-16 18:33:51 +02:00
Viktor Liu
bfa5c21d2d [client] Improve icmp conntrack log (#3963) 2025-06-16 10:12:59 +02:00
Maycon Santos
b1247a14ba [management] Use xID for setup key IDs to avoid id collisions (#3977)
This PR addresses potential ID collisions by switching the setup key ID generation from a hash-based approach to using xid-generated IDs.

Replace the hash function with xid.New().String()
Remove obsolete imports and the Hash() function
2025-06-14 12:24:16 +01:00
Philippe Vaucher
f595057a0b [signal] Set flags from environment variables (#3972) 2025-06-14 00:08:34 +02:00
hakansa
089d442fb2 [client] Display login popup on session expiration (#3955)
This PR implements a feature enhancement to display a login popup when the session expires. Key changes include updating flag handling and client construction to support a new login URL popup, revising login and notification handling logic to use the new popup, and updating status and server-side session state management accordingly.
2025-06-13 23:51:57 +02:00
Viktor Liu
04a3765391 [client] Fix unncessary UI updates (#3785) 2025-06-13 20:38:50 +02:00
Zoltan Papp
d24d8328f9 [client] Propagation networks for Android client (#3966)
Add networks propagation
2025-06-13 11:04:17 +02:00
Vlad
4f63996ae8 [management] added events streaming metrics (#3814) 2025-06-12 18:48:54 +01:00
278 changed files with 11539 additions and 5322 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
- uses: git-town/action@v1.2.1
with:
skip-single-stacks: true
skip-single-stacks: true

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -24,8 +24,8 @@ jobs:
id: filter
with:
filters: |
management:
- 'management/**'
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
@@ -148,7 +148,7 @@ jobs:
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -181,6 +181,7 @@ jobs:
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
CONTAINER: "true"
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
@@ -198,6 +199,7 @@ jobs:
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
@@ -211,7 +213,11 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -251,9 +257,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./signal/...
-timeout 10m ./relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -43,7 +43,7 @@ jobs:
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.18"
SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -134,6 +134,7 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: |
set -x
@@ -180,6 +181,7 @@ jobs:
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules
run: go mod tidy

View File

@@ -149,6 +149,7 @@ nfpms:
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
@@ -164,6 +165,7 @@ dockers:
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64

View File

@@ -14,6 +14,9 @@
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -29,13 +32,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>
@@ -47,10 +50,9 @@
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform
### Open Source Network Security in a Single Platform
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -1,6 +1,9 @@
FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,6 +1,7 @@
FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird

View File

@@ -59,10 +59,14 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
@@ -106,8 +110,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +136,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources
@@ -174,6 +178,55 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()

26
client/android/exec.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build android
package android
import (
"fmt"
_ "unsafe"
)
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
// In Android version 11 and earlier, pidfd-related system calls
// are not allowed by the seccomp policy, which causes crashes due
// to SIGSYS signals.
//go:linkname checkPidfdOnce os.checkPidfdOnce
var checkPidfdOnce func() error
func execWorkaround(androidSDKVersion int) {
if androidSDKVersion > 30 { // above Android 11
return
}
checkPidfdOnce = func() error {
return fmt.Errorf("unsupported Android version")
}
}

View File

@@ -0,0 +1,27 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,30 +7,23 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum
}
// PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
// PeerInfoArray is a wrapper of []PeerInfo
type PeerInfoArray struct {
items []PeerInfo
}
// Add new PeerInfo to the collection
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array PeerInfoArray) Get(i int) *PeerInfo {
func (array *PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array PeerInfoArray) Size() int {
func (array *PeerInfoArray) Size() int {
return len(array.items)
}

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal"
)
// Preferences export a subset of the internal config for gomobile
// Preferences exports a subset of the internal config for gomobile
type Preferences struct {
configInput internal.ConfigInput
}
// NewPreferences create new Preferences instance
// NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{
ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci}
}
// GetManagementURL read url from config file
// GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err
}
// SetManagementURL store the given url and wait for commit
// SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url
}
// GetAdminURL read url from config file
// GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err
}
// SetAdminURL store the given url and wait for commit
// SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url
}
// GetPreSharedKey read preshared key from config file
// GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err
}
// SetPreSharedKey store the given key and wait for commit
// SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
// SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)
return err

View File

@@ -17,10 +17,18 @@ import (
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const errCloseConnection = "Failed to close connection: %v"
var (
logFileCount uint32
systemInfoFlag bool
uploadBundleFlag bool
uploadBundleURLFlag string
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -88,12 +96,13 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -105,7 +114,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -223,12 +232,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -255,7 +265,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -297,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
)
}
return statusOutputString
@@ -375,3 +385,15 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
@@ -38,10 +37,7 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
@@ -71,14 +67,10 @@ var (
interfaceName string
wireguardPort uint16
networkMonitor bool
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{
@@ -123,15 +115,9 @@ func init() {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
}
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
@@ -142,7 +128,6 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
rootCmd.AddCommand(statusCmd)
@@ -153,9 +138,6 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -184,11 +166,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
@@ -196,14 +175,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
termCh := make(chan os.Signal, 1)
signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
go func() {
done := ctx.Done()
defer cancel()
select {
case <-done:
case <-ctx.Done():
case <-termCh:
}
log.Info("shutdown signal received")
cancel()
}()
}

View File

@@ -1,12 +1,15 @@
//go:build !ios && !android
package cmd
import (
"context"
"fmt"
"runtime"
"strings"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
@@ -14,6 +17,16 @@ import (
"github.com/netbirdio/netbird/client/server"
)
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
}
var (
serviceName string
serviceEnvVars []string
)
type program struct {
ctx context.Context
cancel context.CancelFunc
@@ -22,12 +35,31 @@ type program struct {
serverInstanceMu sync.Mutex
}
func init() {
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
rootCmd.AddCommand(serviceCmd)
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
ctx = internal.CtxInitState(ctx)
return &program{ctx: ctx, cancel: cancel}
}
func newSVCConfig() *service.Config {
func newSVCConfig() (*service.Config, error) {
config := &service.Config{
Name: serviceName,
DisplayName: "Netbird",
@@ -36,23 +68,47 @@ func newSVCConfig() *service.Config {
EnvVars: make(map[string]string),
}
if len(serviceEnvVars) > 0 {
extraEnvs, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
return nil, fmt.Errorf("parse service environment variables: %w", err)
}
config.EnvVars = extraEnvs
}
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
return config, nil
}
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
s, err := service.New(prg, conf)
if err != nil {
log.Fatal(err)
return nil, err
}
return s, nil
return service.New(prg, conf)
}
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
func parseServiceEnvVars(envVars []string) (map[string]string, error) {
envMap := make(map[string]string)
for _, env := range envVars {
if env == "" {
continue
}
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("empty environment variable key in: %s", env)
}
envMap[key] = value
}
return envMap, nil
}

View File

@@ -1,3 +1,5 @@
//go:build !ios && !android
package cmd
import (
@@ -47,14 +49,13 @@ func (p *program) Start(svc service.Service) error {
listen, err := net.Listen(split[0], split[1])
if err != nil {
return fmt.Errorf("failed to listen daemon interface: %w", err)
return fmt.Errorf("listen daemon interface: %w", err)
}
go func() {
defer listen.Close()
if split[0] == "unix" {
err = os.Chmod(split[1], 0666)
if err != nil {
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
@@ -100,37 +101,49 @@ func (p *program) Stop(srv service.Service) error {
return nil
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
if err := handleRebrand(cmd); err != nil {
return nil, err
}
if err := util.InitLog(logLevel, logFile); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
cfg, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return nil, err
}
return s, nil
}
var runCmd = &cobra.Command{
Use: "run",
Short: "runs Netbird as service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Run()
if err != nil {
return err
}
return nil
return s.Run()
},
}
@@ -138,31 +151,14 @@ var startCmd = &cobra.Command{
Use: "start",
Short: "starts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
cmd.PrintErrln(err)
return err
}
err = s.Start()
if err != nil {
cmd.PrintErrln(err)
return err
if err := s.Start(); err != nil {
return fmt.Errorf("start service: %w", err)
}
cmd.Println("Netbird service has been started")
return nil
@@ -173,29 +169,14 @@ var stopCmd = &cobra.Command{
Use: "stop",
Short: "stops Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Stop()
if err != nil {
return err
if err := s.Stop(); err != nil {
return fmt.Errorf("stop service: %w", err)
}
cmd.Println("Netbird service has been stopped")
return nil
@@ -206,31 +187,48 @@ var restartCmd = &cobra.Command{
Use: "restart",
Short: "restarts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Restart()
if err != nil {
return err
if err := s.Restart(); err != nil {
return fmt.Errorf("restart service: %w", err)
}
cmd.Println("Netbird service has been restarted")
return nil
},
}
var svcStatusCmd = &cobra.Command{
Use: "status",
Short: "shows Netbird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
status, err := s.Status()
if err != nil {
return fmt.Errorf("get service status: %w", err)
}
var statusText string
switch status {
case service.StatusRunning:
statusText = "Running"
case service.StatusStopped:
statusText = "Stopped"
case service.StatusUnknown:
statusText = "Unknown"
default:
statusText = fmt.Sprintf("Unknown (%d)", status)
}
cmd.Printf("Netbird service status: %s\n", statusText)
return nil
},
}

View File

@@ -1,87 +1,121 @@
//go:build !ios && !android
package cmd
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
var ErrGetServiceStatus = fmt.Errorf("failed to get service status")
// Common service command setup
func setupServiceCommand(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
return handleRebrand(cmd)
}
// Build service arguments for install/reconfigure
func buildServiceArguments() []string {
args := []string{
"service",
"run",
"--config",
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
args = append(args, "--management-url", managementURL)
}
if logFile != "" {
args = append(args, "--log-file", logFile)
}
return args
}
// Configure platform-specific service settings
func configurePlatformSpecificSettings(svcConfig *service.Config) error {
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile != "console" {
setStdLogPath := true
dir := filepath.Dir(logFile)
if _, err := os.Stat(dir); err != nil {
if err = os.MkdirAll(dir, 0750); err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
return nil
}
// Create fully configured service config for install/reconfigure
func createServiceConfigForInstall() (*service.Config, error) {
svcConfig, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
svcConfig.Arguments = buildServiceArguments()
if err = configurePlatformSpecificSettings(svcConfig); err != nil {
return nil, fmt.Errorf("configure platform-specific settings: %w", err)
}
return svcConfig, nil
}
var installCmd = &cobra.Command{
Use: "install",
Short: "installs Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
if err := setupServiceCommand(cmd); err != nil {
return err
}
svcConfig := newSVCConfig()
svcConfig.Arguments = []string{
"service",
"run",
"--config",
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
}
if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
}
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile != "console" {
setStdLogPath := true
dir := filepath.Dir(logFile)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0750)
if err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
cmd.PrintErrln(err)
return err
}
err = s.Install()
if err != nil {
cmd.PrintErrln(err)
return err
if err := s.Install(); err != nil {
return fmt.Errorf("install service: %w", err)
}
cmd.Println("Netbird service has been installed")
@@ -93,27 +127,109 @@ var uninstallCmd = &cobra.Command{
Use: "uninstall",
Short: "uninstalls Netbird service from system",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := setupServiceCommand(cmd); err != nil {
return err
}
cmd.SetOut(cmd.OutOrStdout())
cfg, err := newSVCConfig()
if err != nil {
return fmt.Errorf("create service config: %w", err)
}
err := handleRebrand(cmd)
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return err
}
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall service: %w", err)
}
cmd.Println("Netbird service has been uninstalled")
return nil
},
}
var reconfigureCmd = &cobra.Command{
Use: "reconfigure",
Short: "reconfigures Netbird service with new settings",
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil {
return err
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
return err
return fmt.Errorf("create service: %w", err)
}
err = s.Uninstall()
if err != nil {
return err
if wasRunning {
cmd.Println("Stopping Netbird service...")
if err := s.Stop(); err != nil {
cmd.Printf("Warning: failed to stop service: %v\n", err)
}
}
cmd.Println("Netbird service has been uninstalled")
cmd.Println("Removing existing service configuration...")
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall existing service: %w", err)
}
cmd.Println("Installing service with new configuration...")
if err := s.Install(); err != nil {
return fmt.Errorf("install service with new config: %w", err)
}
if wasRunning {
cmd.Println("Starting Netbird service...")
if err := s.Start(); err != nil {
return fmt.Errorf("start service after reconfigure: %w", err)
}
cmd.Println("Netbird service has been reconfigured and started")
} else {
cmd.Println("Netbird service has been reconfigured")
}
return nil
},
}
func isServiceRunning() (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return false, err
}
status, err := s.Status()
if err != nil {
return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err)
}
return status == service.StatusRunning, nil
}

263
client/cmd/service_test.go Normal file
View File

@@ -0,0 +1,263 @@
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {
name string
envVars []string
expected map[string]string
expectErr bool
}{
{
name: "Valid single env var",
envVars: []string{"LOG_LEVEL=debug"},
expected: map[string]string{
"LOG_LEVEL": "debug",
},
},
{
name: "Valid multiple env vars",
envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"},
expected: map[string]string{
"LOG_LEVEL": "debug",
"CUSTOM_VAR": "value",
},
},
{
name: "Env var with spaces",
envVars: []string{" KEY = value "},
expected: map[string]string{
"KEY": "value",
},
},
{
name: "Invalid format - no equals",
envVars: []string{"INVALID"},
expectErr: true,
},
{
name: "Invalid format - empty key",
envVars: []string{"=value"},
expectErr: true,
},
{
name: "Empty value is valid",
envVars: []string{"KEY="},
expected: map[string]string{
"KEY": "",
},
},
{
name: "Empty slice",
envVars: []string{},
expected: map[string]string{},
},
{
name: "Empty string in slice",
envVars: []string{"", "KEY=value", ""},
expected: map[string]string{"KEY": "value"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseServiceEnvVars(tt.envVars)
if tt.expectErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestServiceConfigWithEnvVars tests service config creation with env vars
func TestServiceConfigWithEnvVars(t *testing.T) {
originalServiceName := serviceName
originalServiceEnvVars := serviceEnvVars
defer func() {
serviceName = originalServiceName
serviceEnvVars = originalServiceEnvVars
}()
serviceName = "test-service"
serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"}
cfg, err := newSVCConfig()
require.NoError(t, err)
assert.Equal(t, "test-service", cfg.Name)
assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"])
assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"])
if runtime.GOOS == "linux" {
assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"])
}
}

View File

@@ -26,6 +26,7 @@ var (
statusFilter string
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
)
var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -69,7 +71,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err
}
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -86,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string
switch {
case detailFlag:
@@ -117,7 +122,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -153,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag()
}
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil
}

View File

@@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
"This overrides any policies received from the management service.")
}

View File

@@ -103,13 +103,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -3,6 +3,7 @@ package conntrack
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
@@ -19,6 +20,10 @@ const (
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
)
// ICMPConnKey uniquely identifies an ICMP connection
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
@@ -50,6 +55,72 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger
}
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
func (t *ICMPTracker) TrackInbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
}
})
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
}
b.ResetTimer()

View File

@@ -104,6 +104,12 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages
@@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
// FilterOutBound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
// FilterInbound filters incoming packets
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.filterInbound(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false
}
@@ -671,7 +662,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
}
}
@@ -684,7 +675,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
}
}
@@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// dropFilter implements filtering logic for incoming packets.
// filterInbound implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, size int) bool {
func (m *Manager) filterInbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -747,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return false
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}

View File

@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
// Test packet
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut, 0)
manager.filterOutbound(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0)
manager.filterInbound(testIn, 0)
}
})
}
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.filterOutbound(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
}
})
}
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
// Data transfer
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
// Connection teardown
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
}
})
}
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
// Pre-generate test packets
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
manager.filterOutbound(outPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
}
})
})
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
}
})
})

View File

@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)

View File

@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), 0) {
if m.filterInbound(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0)
result := manager.filterOutbound(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
result = manager.filterOutbound(buf.Bytes(), 0)
require.False(t, result)
}
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0)
drop = manager.filterInbound(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
}
}

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64

View File

@@ -0,0 +1,408 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -0,0 +1,416 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -0,0 +1,145 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -0,0 +1,96 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers {
monoTime := record.LastActivity.Load()
activities[key] = monotime.Time(monoTime)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
var record *PeerRecord
record, exists := r.peers[publicKey]
if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else {
record = &PeerRecord{
Address: address,
}
record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record
}
r.addrToPeer[address] = record
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := int64(monotime.Now())
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -0,0 +1,25 @@
package bind
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/monotime"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,6 +1,7 @@
package bind
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -15,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type RecvMessage struct {
@@ -51,22 +53,24 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
activityRecorder *ActivityRecorder
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
}
rc := receiverCreator{
@@ -100,6 +104,10 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close()
}
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
@@ -146,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
@@ -199,6 +207,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
@@ -257,6 +270,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil
}
}
@@ -272,3 +292,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
}
msgsPool.Put(msgs)
}
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
var allAddresses []string
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
delete(m.addressMap, addr)
}
allAddresses = append(allAddresses, addresses...)
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
}
}
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
m.addressMapMu.RUnlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,22 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
return
}
// Userspace mode: UDPConn wrapper around nbnet.PacketConn
if wrapped, ok := m.params.UDPConn.(*UDPConn); ok {
if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
}
}
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type UDPConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr)
}
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) performFilterCheck(addr net.Addr) error {
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
@@ -276,3 +278,7 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -16,6 +16,8 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -36,16 +38,18 @@ const (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.startUAPI()
return wgCfg
@@ -87,7 +91,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -104,7 +120,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
@@ -205,6 +224,10 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr)
}
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error

View File

@@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
name string
device *device.Device
@@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
}
}
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
@@ -70,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -9,11 +9,11 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
// FilterOutbound filter outgoing packets from host to external destinations
FilterOutbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) {
if !filter.FilterInbound(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -99,8 +100,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if err != nil {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
UDPConn: udpConn,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,

View File

@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
)
type WGConfigurer interface {
@@ -19,4 +20,5 @@ type WGConfigurer interface {
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
const (
@@ -29,6 +30,11 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
var (
// ErrIfaceNotFound is returned when the WireGuard interface is not found
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -43,6 +49,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
DisableDNS bool
}
// WGIface represents an interface instance
@@ -116,6 +123,9 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
@@ -125,6 +135,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey)
@@ -134,6 +147,9 @@ func (w *WGIface) RemovePeer(peerKey string) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
@@ -143,6 +159,9 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
@@ -213,10 +232,29 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
// GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.GetStats()
}
func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return nil
}
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.FullStats()
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil

View File

@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.

View File

@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
var skipProxy bool
if val := os.Getenv(EnvSkipProxy); val != "" {
skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
}
if skipProxy {
return nsTunDev, tunNet, nil

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
}
// AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
}
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}

View File

@@ -11,6 +11,8 @@ import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil {
return 0, ctx.Err()
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
return proxyBind.NewProxyBind(w.bind)
}
func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{
localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
}
return p
}
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr
}
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() {
if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err)
return
}

View File

@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},

View File

@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{

View File

@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
}
if _, err := config.apply(input); err != nil {
@@ -317,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
@@ -416,9 +414,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
if runtime.GOOS == "android" {
// default to disabled SSH on Android for security
log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
@@ -26,11 +25,11 @@ import (
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
rosenpassEnabled bool
lazyConnMgr *manager.Manager
@@ -39,12 +38,12 @@ type ConnMgr struct {
lazyCtxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
rosenpassEnabled: engineConfig.RosenpassEnabled,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
@@ -64,6 +63,11 @@ func (e *ConnMgr) Start(ctx context.Context) {
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
@@ -83,7 +87,12 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
return nil
}
log.Warnf("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
@@ -133,7 +142,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
@@ -201,7 +210,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
if !ok {
return
}
defer conn.Close()
defer conn.Close(false)
if !e.isStartedWithLazyMgr() {
return
@@ -211,23 +220,27 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return conn, true
return
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
// If locally the lazy connection is disabled, we force the peer connection open.
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
}
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
}
func (e *ConnMgr) Close() {
@@ -244,7 +257,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
@@ -275,7 +288,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {

View File

@@ -17,7 +17,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -526,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
addr := net.UDPAddr{Port: initPort}
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
closeConnWithLog(conn)
return initPort, nil
return returnPort, nil
}
// if the port is already in use, ask the system for a free port

View File

@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
shouldMatch bool
}{
{
name: "not provided, fallback to default",
name: "when port is 0 use random port",
port: 0,
want: 51820,
shouldMatch: true,
want: 0,
shouldMatch: false,
},
{
name: "provided and available",
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil {
t.Errorf("freePort error = %v", err)
}
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
_ = c1.Close()
}(c1)
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
tests[1].port++
tests[1].want++
}
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -167,6 +167,7 @@ type BundleGenerator struct {
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
archive *zip.Writer
}
@@ -175,6 +176,7 @@ type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
type GeneratorDependencies struct {
@@ -185,6 +187,12 @@ type GeneratorDependencies struct {
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
// Default to 1 log file for backward compatibility when 0 is provided
logFileCount := cfg.LogFileCount
if logFileCount == 0 {
logFileCount = 1
}
return &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
@@ -196,6 +204,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
}
@@ -561,32 +570,8 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err)
}
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
// add rotated log files based on logFileCount
g.addRotatedLogFiles(logDir)
stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
@@ -670,6 +655,52 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
return nil
}
// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount
func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
if g.logFileCount == 0 {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
return
}
if len(files) == 0 {
return
}
// sort files by modification time (newest first)
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().After(fj.ModTime())
})
// include up to logFileCount rotated files
maxFiles := int(g.logFileCount)
if maxFiles > len(files) {
maxFiles = len(files)
}
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,

View File

@@ -11,9 +11,10 @@ import (
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 1
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityUpstream = 50
PriorityDefault = 1
)
type SubdomainMatcher interface {

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: true,
},
},
{
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityUpstream},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDNSRoute: true,
nbdns.PriorityUpstream: false,
},
},
{
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityUpstream},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true,
},
},
}
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "test.sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
},
query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",

View File

@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain,
handler: s.localResolver,
priority: PriorityMatchDomain,
priority: PriorityLocal,
})
for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain
basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault
}
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i
// Check if we're about to overlap with the next priority tier
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
// Check if we're about to overlap with the next priority tier.
// This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
domainGroup.domain, PriorityUpstream-PriorityDefault)
break
}

View File

@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
}
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct {
name string
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
}
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
}
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
},
expectedHandlers: map[string]string{
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain + 1,
priority: PriorityUpstream + 1,
},
// Keep existing groups with their original priorities
{
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
},
expectedHandlers: map[string]string{
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
// Add group3 with lowest priority
{
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain - 2,
priority: PriorityUpstream - 2,
},
},
expectedHandlers: map[string]string{
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "new.com",
handler: &mockHandler{
Id: "upstream-new",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2
zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
}
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
}
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -2,6 +2,7 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():
log.Tracef("%s has been stopped", u)
logger.Tracef("%s has been stopped", u)
return
default:
}
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
}
if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return
}
u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
}
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -84,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
}
return false
}
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -36,3 +36,10 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
}
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct {
listenAddress string
ttl uint32
@@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewall.Manager
firewall firewaller
resolver resolver
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
}
}
@@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server
mux := dns.NewServeMux()
f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
// TCP server
tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{
Addr: f.listenAddress,
Net: "tcp",
@@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown)
return <-errCh
}
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
}
func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, query, resp, domain, err)
return nil
}
f.updateInternalState(domain, ips)
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
@@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
}
}
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
@@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,11 +1,21 @@
package dnsfwd
import (
"context"
"fmt"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId
storedMappings map[string]route.ResID
queryDomain string
expectedResId route.ResID
}{
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com",
queryDomain: "foo.example.org",
expectedResId: "",
},
{
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
})
}
}
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -38,7 +38,6 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
@@ -62,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -139,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -175,8 +170,7 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
statusRecorder *peer.Status
firewall firewallManager.Manager
routeManager routemanager.Manager
@@ -383,7 +377,13 @@ func (e *Engine) Start() error {
}
e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer()
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
@@ -400,16 +400,13 @@ func (e *Engine) Start() error {
InitialRoutes: initialRoutes,
StateManager: e.stateManager,
DNSServer: dnsServer,
DNSFeatureFlag: dnsFeatureFlag,
PeerStore: e.peerStore,
DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes,
})
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
if err := e.routeManager.Init(); err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
}
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -451,9 +448,7 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
@@ -488,9 +483,9 @@ func (e *Engine) createFirewall() error {
}
func (e *Engine) initFirewall() error {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
return fmt.Errorf("set firewall: %w", err)
}
if e.config.BlockLANAccess {
@@ -1009,8 +1004,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1021,6 +1014,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
@@ -1255,14 +1249,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close()
conn.Close(false)
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil
}
@@ -1302,13 +1292,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1331,11 +1320,16 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
msgType := msg.GetBody().GetType()
if msgType != sProto.Body_GO_IDLE {
e.connMgr.ActivatePeer(e.ctx, conn)
}
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1392,6 +1386,8 @@ func (e *Engine) receiveSignalEvents() {
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE:
case sProto.Body_GO_IDLE:
e.connMgr.DeactivatePeer(conn)
}
return nil
@@ -1489,7 +1485,12 @@ func (e *Engine) close() {
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
if runtime.GOOS != "android" {
// nolint:nilnil
return nil, nil, false, nil
}
info := system.GetInfo(e.ctx)
info.SetFlags(
e.config.RosenpassEnabled,
@@ -1506,11 +1507,12 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err
return nil, nil, false, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil
dnsFeatureFlag := toDNSFeatureFlag(netMap)
return routes, &dnsCfg, dnsFeatureFlag, nil
}
func (e *Engine) newWgIface() (*iface.WGIface, error) {
@@ -1527,6 +1529,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
switch runtime.GOOS {
@@ -1557,18 +1560,14 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err
}
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil {
return nil, e.dnsServer, nil
return e.dnsServer, nil
}
switch runtime.GOOS {
case "android":
routes, dnsConfig, err := e.readInitialSettings()
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
@@ -1579,19 +1578,19 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
e.config.DisableDNS,
)
go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return nil, dnsServer, nil
return dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
if err != nil {
return nil, nil, err
return nil, err
}
return nil, dnsServer, nil
return dnsServer, nil
}
}

View File

@@ -36,7 +36,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
@@ -97,6 +97,7 @@ type MockWGIface struct {
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
@@ -187,6 +188,13 @@ func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}
func (m *MockWGIface) LastActivities() map[string]monotime.Time {
if m.LastActivitiesFunc != nil {
return m.LastActivitiesFunc()
}
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
@@ -392,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr,
})
_, _, err = engine.routeManager.Init()
err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -404,7 +412,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
engine.connMgr.Start(ctx)
type testCase struct {
@@ -793,7 +801,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr.Start(ctx)
defer func() {
@@ -991,7 +999,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr.Start(ctx)
defer func() {
@@ -1473,16 +1481,20 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
type wgIfaceBase interface {
@@ -38,4 +39,5 @@ type wgIfaceBase interface {
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
}

View File

@@ -13,7 +13,7 @@ import (
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface lazyconn.WGIface
wgIface WgInterface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
@@ -22,7 +22,7 @@ type Listener struct {
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener")
d.peerCfg.Log.Infof("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
@@ -59,9 +59,11 @@ func (d *Listener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
@@ -71,7 +73,7 @@ func (d *Listener) ReadPackets() {
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
@@ -81,7 +83,6 @@ func (d *Listener) Close() {
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}

View File

@@ -1,18 +1,27 @@
package activity
import (
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface
wgIface WgInterface
peers map[peerid.ConnID]*Listener
done chan struct{}
@@ -20,7 +29,7 @@ type Manager struct {
mu sync.Mutex
}
func NewManager(wgIface lazyconn.WGIface) *Manager {
func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,

View File

@@ -33,6 +33,15 @@ func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAd
}
// Add this method to the Manager struct
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
m.mu.Lock()
defer m.mu.Unlock()
listener, exists := m.peers[peerConnID]
return listener, exists
}
func TestManager_MonitorPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
@@ -51,7 +60,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID)
if !exists {
t.Fatalf("peer listener not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -128,11 +142,21 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID)
if !exists {
t.Fatalf("peer listener for peer1 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
listener, exists = mgr.GetPeerListener(peerCfg2.PeerConnID)
if !exists {
t.Fatalf("peer listener for peer2 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}

View File

@@ -1,70 +0,0 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}

View File

@@ -1,156 +0,0 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,155 @@
package inactivity
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
const (
checkInterval = 1 * time.Minute
DefaultInactivityThreshold = 15 * time.Minute
MinimumInactivityThreshold = 1 * time.Minute
)
type WgInterface interface {
LastActivities() map[string]monotime.Time
}
type Manager struct {
inactivePeersChan chan map[string]struct{}
iface WgInterface
interestedPeers map[string]*lazyconn.PeerConfig
inactivityThreshold time.Duration
}
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
if err != nil {
inactivityThreshold = DefaultInactivityThreshold
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
}
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
return &Manager{
inactivePeersChan: make(chan map[string]struct{}, 1),
iface: iface,
interestedPeers: make(map[string]*lazyconn.PeerConfig),
inactivityThreshold: inactivityThreshold,
}
}
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
if m == nil {
// return a nil channel that blocks forever
return nil
}
return m.inactivePeersChan
}
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
if m == nil {
return
}
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
return
}
peerCfg.Log.Infof("adding peer to inactivity manager")
m.interestedPeers[peerCfg.PublicKey] = peerCfg
}
func (m *Manager) RemovePeer(peer string) {
if m == nil {
return
}
pi, ok := m.interestedPeers[peer]
if !ok {
return
}
pi.Log.Debugf("remove peer from inactivity manager")
delete(m.interestedPeers, peer)
}
func (m *Manager) Start(ctx context.Context) {
if m == nil {
return
}
ticker := newTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C():
idlePeers, err := m.checkStats()
if err != nil {
log.Errorf("error checking stats: %v", err)
return
}
if len(idlePeers) == 0 {
continue
}
m.notifyInactivePeers(ctx, idlePeers)
}
}
}
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
select {
case m.inactivePeersChan <- inactivePeers:
case <-ctx.Done():
return
default:
return
}
}
func (m *Manager) checkStats() (map[string]struct{}, error) {
lastActivities := m.iface.LastActivities()
idlePeers := make(map[string]struct{})
checkTime := time.Now()
for peerID, peerCfg := range m.interestedPeers {
lastActive, ok := lastActivities[peerID]
if !ok {
// when peer is in connecting state
peerCfg.Log.Warnf("peer not found in wg stats")
continue
}
since := monotime.Since(lastActive)
if since > m.inactivityThreshold {
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
idlePeers[peerID] = struct{}{}
}
}
return idlePeers, nil
}
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
if configuredThreshold == nil {
return DefaultInactivityThreshold, nil
}
if *configuredThreshold < MinimumInactivityThreshold {
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
}
return *configuredThreshold, nil
}

View File

@@ -0,0 +1,114 @@
package inactivity
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
type mockWgInterface struct {
lastActivities map[string]monotime.Time
}
func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
return m.lastActivities
}
func TestPeerTriggersInactivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case inactivePeers := <-manager.inactivePeersChan:
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
case <-time.After(1 * time.Second):
t.Fatal("expected inactivity event, but none received")
}
}
func TestPeerTriggersActivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case <-manager.inactivePeersChan:
t.Fatal("expected inactive peer to be marked inactive")
case <-time.After(1 * time.Second):
// No inactivity event should be received
}
}
// fakeTickerMock implements Ticker interface for testing
type fakeTickerMock struct {
CChan chan time.Time
}
func (f *fakeTickerMock) C() <-chan time.Time {
return f.CChan
}
func (f *fakeTickerMock) Stop() {}

View File

@@ -0,0 +1,24 @@
package inactivity
import "time"
var newTicker = func(d time.Duration) Ticker {
return &realTicker{t: time.NewTicker(d)}
}
type Ticker interface {
C() <-chan time.Time
Stop()
}
type realTicker struct {
t *time.Ticker
}
func (r *realTicker) C() <-chan time.Time {
return r.t.C
}
func (r *realTicker) Stop() {
r.t.Stop()
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
@@ -43,60 +42,46 @@ type Config struct {
type Manager struct {
engineCtx context.Context
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
activityManager *activity.Manager
inactivityManager *inactivity.Manager
// Route HA group management
// If any peer in the same HA group is active, all peers in that group should prevent going idle
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings
onInactive chan peerid.ConnID
routesMu sync.RWMutex
}
// NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
engineCtx: engineCtx,
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID),
}
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
if wgIface.IsUserspaceBind() {
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
} else {
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
@@ -131,24 +116,28 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
}
}
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
len(m.haGroupToPeers), len(m.peerToHAGroups))
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
if m.inactivityManager != nil {
go m.inactivityManager.Start(ctx)
}
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
m.onPeerActivity(peerConnID)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
}
}
}
// ExcludePeer marks peers for a permanent connection
@@ -156,7 +145,7 @@ func (m *Manager) Start(ctx context.Context) {
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -187,7 +176,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil {
if err := m.addActivePeer(&peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
@@ -217,20 +206,24 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(peerCfg)
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -240,7 +233,7 @@ func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerCon
continue
}
if err := m.addActivePeer(ctx, cfg); err != nil {
if err := m.addActivePeer(&cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
@@ -257,7 +250,7 @@ func (m *Manager) RemovePeer(peerID string) {
// ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
func (m *Manager) ActivatePeer(peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID)
@@ -265,15 +258,43 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool)
return false
}
if !m.activateSinglePeer(ctx, cfg, mp) {
cfg.Log.Infof("activate peer from inactive state by remote signal message")
if !m.activateSinglePeer(cfg, mp) {
return false
}
m.activateHAGroupPeers(ctx, peerID)
m.activateHAGroupPeers(cfg)
return true
}
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
@@ -295,82 +316,120 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
return cfg, mp
}
// activateSinglePeer activates a single peer (internal method)
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
// activateSinglePeer activates a single peer
// return true if the peer was activated, false if it was already active
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
if mp.expectedWatcher == watcherInactivity {
return false
}
cfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
m.inactivityManager.AddPeer(cfg)
return true
}
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
var peersToActivate []string
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
if len(haGroups) == 0 {
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
m.routesMu.RUnlock()
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
return
}
activatedCount := 0
for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers {
if peerID == triggerPeerID {
continue
if peerID != triggeredPeerCfg.PublicKey {
peersToActivate = append(peersToActivate, peerID)
}
}
}
m.routesMu.RUnlock()
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
activatedCount := 0
for _, peerID := range peersToActivate {
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
if m.activateSinglePeer(cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggerPeerID, haGroups)
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
}
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(&peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeers[peerCfg.PublicKey] = peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
peerCfg: peerCfg,
expectedWatcher: watcherInactivity,
}
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
m.inactivityManager.AddPeer(peerCfg)
return nil
}
@@ -382,12 +441,7 @@ func (m *Manager) removePeer(peerID string) {
cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.inactivityManager.RemovePeer(cfg.PublicKey)
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
@@ -397,12 +451,8 @@ func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
@@ -415,7 +465,56 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed")
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
return true
}
}
return false
}
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// If any peer in the group is active, do defer idle
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
return true
}
}
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -432,89 +531,56 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
mp.peerCfg.Log.Infof("detected peer activity")
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
if !m.activateSinglePeer(mp.peerCfg, mp) {
return
}
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
for peerID := range peerIDs {
peerCfg, ok := m.managedPeers[peerID]
if !ok {
log.Errorf("peer not found by peerId: %v", peerID)
continue
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
continue
}
mp.peerCfg.Log.Infof("connection timed out")
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
continue
}
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
continue
}
mp.peerCfg.Log.Infof("start activity monitor")
mp.peerCfg.Log.Infof("connection timed out")
mp.expectedWatcher = watcherActivity
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
mp.expectedWatcher = watcherActivity
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
continue
}
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

Some files were not shown because too many files have changed in this diff Show More