Compare commits

..

83 Commits

Author SHA1 Message Date
snyk-bot
89064bb5d5 fix: management/Dockerfile to reduce vulnerabilities
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-UBUNTU2404-TAR-10769052
- https://snyk.io/vuln/SNYK-UBUNTU2404-GLIBC-10321975
- https://snyk.io/vuln/SNYK-UBUNTU2404-GLIBC-10321975
- https://snyk.io/vuln/SNYK-UBUNTU2404-PAM-8303372
- https://snyk.io/vuln/SNYK-UBUNTU2404-PAM-8352843
2025-08-05 03:30:09 +00:00
Pascal Fischer
d1e0b7f4fb [management] get peer groups without lock (#4280) 2025-08-05 01:11:44 +02:00
Viktor Liu
beb66208a0 [management, client] Add API to change the network range (#4177) 2025-08-04 16:45:49 +02:00
Viktor Liu
58eb3c8cc2 [client] Increase ip rule priorities to avoid conflicts (#4273) 2025-08-04 11:20:43 +02:00
Viktor Liu
b5ed94808c [management, client] Add logout feature (#4268) 2025-08-04 10:17:36 +02:00
Pascal Fischer
552dc60547 [management] migrate group peers into seperate table (#4096) 2025-08-01 12:22:07 +02:00
Viktor Liu
71bb09d870 [client] Improve userspace filter logging performance (#4221) 2025-07-31 14:36:30 +02:00
Viktor Liu
5de61f3081 [client] Fix dns ipv6 upstream (#4257) 2025-07-30 20:28:19 +02:00
Vlad
541e258639 [management] add account deleted event (#4255) 2025-07-30 17:49:50 +03:00
Bilgeworth
34042b8171 [misc] devcontainer Dockerfile: pin gopls to v0.18.1 (latest that supports golang 1.23) (#4240)
Container will fail to build with newer versions of gopls unless golang is updated to 1.24. The latest stable version supporting 1.23 is gopls v0.18.1
2025-07-29 20:52:18 +02:00
hakansa
a72ef1af39 [client] Fix error handling for set config request on CLI (#4237)
[client] Fix error handling for set config request on CLI (#4237)
2025-07-29 20:38:44 +03:00
Viktor Liu
980a6eca8e [client] Disable the dns host manager properly if disabled through management (#4241) 2025-07-29 19:37:18 +02:00
hakansa
8c8473aed3 [client] Add support for disabling profiles feature via command line flag (#4235)
* Add support for disabling profiles feature via command line flag

* Add profiles disabling flag to service command

* Refactor profile menu initialization and enhance error notifications in event handlers
2025-07-29 13:03:15 +03:00
hakansa
e1c66a8124 [client] Fix profile directory path handling based on NB_STATE_DIR (#4229)
[client] Fix profile directory path handling based on NB_STATE_DIR (#4229)
2025-07-28 13:36:48 +03:00
Zoltan Papp
d89e6151a4 [client] Fix pre-shared key state in wg show (#4222) 2025-07-25 22:52:48 +02:00
hakansa
3d9be5098b [client]: deprecate config flag (#4224) 2025-07-25 18:43:48 +03:00
hakansa
cb8b6ca59b [client] Feat: Support Multiple Profiles (#3980)
[client] Feat: Support Multiple Profiles (#3980)
2025-07-25 16:54:46 +03:00
Viktor Liu
e0d9306b05 [client] Add detailed routes and resolved IPs to debug bundle (#4141) 2025-07-25 15:31:06 +02:00
Viktor Liu
2c4ac33b38 [client] Remove and deprecate the admin url functionality (#4218) 2025-07-25 15:15:38 +02:00
Zoltan Papp
31872a7fb6 [client] Fix UDP proxy to notify listener when remote conn closed (#4199)
* Fix UDP proxy to notify listener when remote conn closed

* Fix sender tests to use t.Errorf for timeout assertions

* Fix potential nil pointer
2025-07-25 14:14:45 +02:00
Viktor Liu
cb85d3f2fc [client] Always register NetBird with plain Linux DNS and use original servers as upstream (#3967) 2025-07-25 11:46:04 +02:00
Krzysztof Nazarewski (kdn)
af8687579b client: container: support CLI with entrypoint addition (#4126)
This will allow running netbird commands (including debugging) against the daemon and provide a flow similar to non-container usages.

It will by default both log to file and stderr so it can be handled more uniformly in container-native environments.
2025-07-25 11:44:30 +02:00
Louis Li
3f82698089 [client] make ICE failed timeout configurable (#4211) 2025-07-25 10:36:11 +02:00
Pascal Fischer
cb1e437785 [client] handle order of check when checking order of files in isChecksEqual (#4219) 2025-07-24 21:00:51 +02:00
Pascal Fischer
c435c2727f [management] Log BufferUpdateAccountPeers caller (#4217) 2025-07-24 18:33:58 +02:00
Ali Amer
643730f770 [client] Correct minor issues in --filter-by-connection-type flag implementation for status command (#4214)
Signed-off-by: aliamerj <aliamer19ali@gmail.com>
2025-07-24 17:51:27 +02:00
Pascal Fischer
04fae00a6c [management] Log UpdateAccountPeers caller (#4216) 2025-07-24 17:44:48 +02:00
Pedro Maia Costa
1a9ea32c21 [management] scheduler cancel all jobs (#4158) 2025-07-24 16:25:21 +01:00
Pedro Maia Costa
0ea5d020a3 [management] extra settings integrated validator (#4136) 2025-07-24 16:12:29 +01:00
Viktor Liu
459c9ef317 [client] Add env and status flags for netbird service command (#3975) 2025-07-24 13:34:55 +02:00
Viktor Liu
e5e275c87a [client] Fix legacy routing exclusion routes in kernel mode (#4167) 2025-07-24 13:34:36 +02:00
Zoltan Papp
d311f57559 [ci] Temporarily disable race detection in Relay (#4210) 2025-07-24 13:14:49 +02:00
Zoltan Papp
1a28d18cde [client] Fix race issues in lazy tests (#4181)
* Fix race issues in lazy tests

* Fix test failure due to incorrect peer listener identification
2025-07-23 21:03:29 +02:00
Philippe Vaucher
91e7423989 [misc] Docker compose improvements (#4037)
* Use container defaults

* Remove docker compose version when generating zitadel config
2025-07-22 19:44:49 +02:00
Zoltan Papp
86c16cf651 [server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications in case the closed race dials.
In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
2025-07-21 19:58:17 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
Zoltan Papp
3e6eede152 [client] Fix elapsed time calculation when machine is in sleep mode (#4140) 2025-07-12 11:10:45 +02:00
Vlad
a76c8eafb4 [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-11 12:37:14 +03:00
Pedro Maia Costa
2b9f331980 always suffix ephemeral peer name (#4138) 2025-07-11 10:29:10 +01:00
Viktor Liu
a7ea881900 [client] Add rotated logs flag for debug bundle generation (#4100) 2025-07-10 16:13:53 +02:00
Vlad
8632dd15f1 [management] added cleanupWindow for collecting several ephemeral peers to delete (#4130)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-10 15:21:01 +02:00
Zoltan Papp
e3b40ba694 Update cli description of lazy connection (#4133) 2025-07-10 15:00:58 +02:00
Zoltan Papp
e59d75d56a Nil check in iface configurer (#4132) 2025-07-10 14:24:20 +02:00
Zoltan Papp
408f423adc [client] Disable pidfd check on Android 11 and below (#4127)
Disable pidfd check on Android 11 and below

On Android 11 (SDK <= 30) and earlier, pidfd-related system calls
are blocked by seccomp policies, causing SIGSYS crashes.

This change overrides `checkPidfdOnce` to return an error on
affected versions, preventing the use of unsupported pidfd features.
2025-07-09 22:16:08 +02:00
Misha Bragin
f17dd3619c [misc] update image in README.md (#4122) 2025-07-09 15:49:09 +02:00
Bethuel Mmbaga
969f1ed59a [management] Remove deleted user peers from groups on user deletion (#4121)
Refactors peer deletion to centralize group cleanup logic, ensuring deleted peers are consistently removed from all groups in one place.

- Removed redundant group removal code from DefaultAccountManager.DeletePeer
- Added group removal logic inside deletePeers to handle both single and multiple peer deletions
2025-07-09 10:14:10 +03:00
M. Essam
768ba24fda [management,rest] Add name/ip filters to peer management rest client (#4112) 2025-07-08 18:08:13 +02:00
Zoltan Papp
8942c40fde [client] Fix nil pointer exception in lazy connection (#4109)
Remove unused variable
2025-07-06 15:13:14 +02:00
Zoltan Papp
fbb1b55beb [client] refactor lazy detection (#4050)
This PR introduces a new inactivity package responsible for monitoring peer activity and notifying when peers become inactive.
Introduces a new Signal message type to close the peer connection after the idle timeout is reached.
Periodically checks the last activity of registered peers via a Bind interface.
Notifies via a channel when peers exceed a configurable inactivity threshold.
Default settings
DefaultInactivityThreshold is set to 15 minutes, with a minimum allowed threshold of 1 minute.

Limitations
This inactivity check does not support kernel WireGuard integration. In kernel–user space communication, the user space side will always be responsible for closing the connection.
2025-07-04 19:52:27 +02:00
Viktor Liu
77ec32dd6f [client] Implement dns routes for Android (#3989) 2025-07-04 16:43:11 +02:00
Bethuel Mmbaga
8c09a55057 [management] Log user id on account mismatch (#4101) 2025-07-04 10:51:58 +03:00
Pedro Maia Costa
f603ddf35e management: fix store get account peers without lock (#4092) 2025-07-04 08:44:08 +01:00
Krzysztof Nazarewski (kdn)
996b8c600c [management] replace invalid user with a clear error message about mismatched logins (#4097) 2025-07-03 16:36:36 +02:00
Maycon Santos
c4ed11d447 [client] Avoid logging setup keys on error message (#3962) 2025-07-03 16:22:18 +02:00
Viktor Liu
9afbecb7ac [client] Use unique sequence numbers for bsd routes (#4081)
updates the route manager on Unix to use a unique, incrementing sequence number for each route message instead of a fixed value.

Replace the static Seq: 1 with a call to r.getSeq()
Add an atomic seq field and the getSeq method in SysOps
2025-07-03 09:02:53 +02:00
Maycon Santos
2c81cf2c1e [management] Add account onboarding (#4084)
This PR introduces a new onboarding feature to handle such flows in the dashboard by defining an AccountOnboarding model, persisting it in the store, exposing CRUD operations in the manager and HTTP handlers, and updating API schemas and tests accordingly.

Add AccountOnboarding struct and embed it in Account
Extend Store and DefaultAccountManager with onboarding methods and SQL migrations
Update HTTP handlers, API types, OpenAPI spec, and add end-to-end tests
2025-07-03 09:01:32 +02:00
Pascal Fischer
551cb4e467 [management] expect specific error types on registration with setup key (#4094) 2025-07-02 20:04:28 +02:00
Maycon Santos
57961afe95 [doc] Add forum link (#4093)
* Add forum link

* Add forum link
2025-07-02 18:40:07 +02:00
Pascal Fischer
22678bce7f [management] add uniqueness constraint for peer ip and label and optimize generation (#4042) 2025-07-02 18:13:10 +02:00
Maycon Santos
6c633497bc [management] fix network update test for delete policy (#4086)
when adding a peer we calculate the network map an account using backpressure functions and some updates might arrive around the time we are deleting a policy.

This change ensures we wait enough time for the updates from add peer to be sent and read before continuing with the test logic
2025-07-02 12:25:31 +02:00
Carlos Hernandez
6922826919 [client] Support fullstatus without probes (#4052) 2025-07-02 10:42:47 +02:00
Maycon Santos
56a1a75e3f [client] Support random wireguard port on client (#4085)
Adds support for using a random available WireGuard port when the user specifies port `0`.

- Updates `freePort` logic to bind to the requested port (including `0`) without falling back to the default.
- Removes default port assignment in the configuration path, allowing `0` to propagate.
- Adjusts tests to handle dynamically assigned ports when using `0`.
2025-07-02 09:01:02 +02:00
Ali Amer
d9402168ad [management] Add option to disable default all-to-all policy (#3970)
This PR introduces a new configuration option `DisableDefaultPolicy` that prevents the creation of the default all-to-all policy when new accounts are created. This is useful for automation scenarios where explicit policies are preferred.
### Key Changes:
- Added DisableDefaultPolicy flag to the management server config
- Modified account creation logic to respect this flag
- Updated all test cases to explicitly pass the flag (defaulting to false to maintain backward compatibility)
- Propagated the flag through the account manager initialization chain

### Testing:

- Verified default behavior remains unchanged when flag is false
- Confirmed no default policy is created when flag is true
- All existing tests pass with the new parameter
2025-07-02 02:41:59 +02:00
Krzysztof Nazarewski (kdn)
dbdef04b9e [misc] getting-started-with-zitadel.sh: drop unnecessary port 8080 (#4075) 2025-07-02 02:35:13 +02:00
Maycon Santos
29cbfe8467 [misc] update sign pipeline version to v0.0.20 (#4082) 2025-07-01 16:23:31 +02:00
Maycon Santos
6ce8643368 [client] Run login popup on goroutine (#4080) 2025-07-01 13:45:55 +02:00
Krzysztof Nazarewski (kdn)
07d1ad35fc [misc] start the service after installation on arch linux (#4071) 2025-06-30 12:02:03 +02:00
Krzysztof Nazarewski (kdn)
ef6cd36f1a [misc] fix arch install.sh error with empty temporary dependencies
handle empty var before calling removal command
2025-06-30 11:59:35 +02:00
Krzysztof Nazarewski (kdn)
c1c71b6d39 [client] improve adding route log message (#4034)
from:
  Adding route to 1.2.3.4/32 via invalid IP @ 10 (wt0)
to:
  Adding route to 1.2.3.4/32 via no-ip @ 10 (wt0)
2025-06-30 11:57:42 +02:00
Pascal Fischer
0480507a10 [management] report networkmap duration in ms (#4064) 2025-06-28 11:38:15 +02:00
Krzysztof Nazarewski (kdn)
34ac4e4b5a [misc] fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG (#4055)
* fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG

fixes https://github.com/netbirdio/netbird/issues/4054

* un-quote the number

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-06-26 10:45:00 +02:00
Pascal Fischer
52ff9d9602 [management] remove unused transaction (#4053) 2025-06-26 01:34:22 +02:00
Pascal Fischer
1b73fae46e [management] add breakdown of network map calculation metrics (#4020) 2025-06-25 11:46:35 +02:00
334 changed files with 18214 additions and 6507 deletions

View File

@@ -9,7 +9,7 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest
&& go install -v golang.org/x/tools/gopls@v0.18.1
WORKDIR /app

3
.dockerignore-client Normal file
View File

@@ -0,0 +1,3 @@
*
!client/netbird-entrypoint.sh
!netbird

View File

@@ -16,6 +16,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
- uses: git-town/action@v1.2.1
with:
skip-single-stacks: true
skip-single-stacks: true

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -24,8 +24,8 @@ jobs:
id: filter
with:
filters: |
management:
- 'management/**'
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
@@ -148,7 +148,7 @@ jobs:
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -181,6 +181,7 @@ jobs:
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
CONTAINER: "true"
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
@@ -198,6 +199,7 @@ jobs:
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
@@ -211,7 +213,11 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -251,9 +257,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./signal/...
-timeout 10m ./relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -43,7 +43,7 @@ jobs:
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.19"
SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -134,6 +134,7 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: |
set -x
@@ -180,6 +181,7 @@ jobs:
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules
run: go mod tidy

1
.gitignore vendored
View File

@@ -30,3 +30,4 @@ infrastructure_files/setup-*.env
.vscode
.DS_Store
vendor/
/netbird

View File

@@ -155,13 +155,15 @@ dockers:
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
@@ -171,6 +173,8 @@ dockers:
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
@@ -188,6 +192,8 @@ dockers:
goarm: 6
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
@@ -205,6 +211,8 @@ dockers:
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
@@ -221,6 +229,8 @@ dockers:
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
@@ -238,6 +248,8 @@ dockers:
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"

View File

@@ -14,6 +14,9 @@
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -29,13 +32,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>
@@ -47,10 +50,9 @@
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform
### Open Source Network Security in a Single Platform
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -1,9 +1,27 @@
FROM alpine:3.21.3
# build & run locally with:
# cd "$(git rev-parse --show-toplevel)"
# CGO_ENABLED=0 go build -o netbird ./client
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
RUN apk add --no-cache \
bash \
ca-certificates \
ip6tables \
iproute2 \
iptables
ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -1,18 +1,33 @@
FROM alpine:3.21.0
# build & run locally with:
# cd "$(git rev-parse --show-toplevel)"
# CGO_ENABLED=0 go build -o netbird ./client
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
FROM alpine:3.22.0
RUN apk add --no-cache ca-certificates \
RUN apk add --no-cache \
bash \
ca-certificates \
&& adduser -D -h /var/lib/netbird netbird
WORKDIR /var/lib/netbird
USER netbird:netbird
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true
ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_USE_NETSTACK_MODE="true" \
NB_ENABLE_NETSTACK_LOCAL_FORWARDING="true" \
NB_CONFIG="/var/lib/netbird/config.json" \
NB_STATE_DIR="/var/lib/netbird" \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -64,7 +65,9 @@ type Client struct {
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
@@ -80,7 +83,7 @@ func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapt
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
@@ -115,7 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
@@ -203,8 +206,10 @@ func (c *Client) Networks() *NetworkArray {
continue
}
if routes[0].IsDynamic() {
continue
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
@@ -214,7 +219,7 @@ func (c *Client) Networks() *NetworkArray {
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}

26
client/android/exec.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build android
package android
import (
"fmt"
_ "unsafe"
)
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
// In Android version 11 and earlier, pidfd-related system calls
// are not allowed by the seccomp policy, which causes crashes due
// to SIGSYS signals.
//go:linkname checkPidfdOnce os.checkPidfdOnce
var checkPidfdOnce func() error
func execWorkaround(androidSDKVersion int) {
if androidSDKVersion > 30 { // above Android 11
return
}
checkPidfdOnce = func() error {
return fmt.Errorf("unsupported Android version")
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -37,17 +38,17 @@ type URLOpener interface {
// Auth can register or login new client
type Auth struct {
ctx context.Context
config *internal.Config
config *profilemanager.Config
cfgPath string
}
// NewAuth instantiate Auth struct and validate the management URL
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
inputCfg := internal.ConfigInput{
inputCfg := profilemanager.ConfigInput{
ManagementURL: mgmURL,
}
cfg, err := internal.CreateInMemoryConfig(inputCfg)
cfg, err := profilemanager.CreateInMemoryConfig(inputCfg)
if err != nil {
return nil, err
}
@@ -60,7 +61,7 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
}
// NewAuthWithConfig instantiate Auth based on existing config
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
return &Auth{
ctx: ctx,
config: config,
@@ -110,7 +111,7 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = internal.WriteOutConfig(a.cfgPath, a.config)
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -142,7 +143,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
return fmt.Errorf("backoff cycle failed: %v", err)
}
return internal.WriteOutConfig(a.cfgPath, a.config)
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
}
// Login try register the client on the server

View File

@@ -1,17 +1,17 @@
package android
import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
// Preferences exports a subset of the internal config for gomobile
type Preferences struct {
configInput internal.ConfigInput
configInput profilemanager.ConfigInput
}
// NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{
ci := profilemanager.ConfigInput{
ConfigPath: configPath,
}
return &Preferences{ci}
@@ -23,7 +23,7 @@ func (p *Preferences) GetManagementURL() (string, error) {
return p.configInput.ManagementURL, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
@@ -41,7 +41,7 @@ func (p *Preferences) GetAdminURL() (string, error) {
return p.configInput.AdminURL, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
@@ -59,7 +59,7 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return *p.configInput.PreSharedKey, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
@@ -82,7 +82,7 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -100,7 +100,7 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -113,7 +113,7 @@ func (p *Preferences) GetDisableClientRoutes() (bool, error) {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -131,7 +131,7 @@ func (p *Preferences) GetDisableServerRoutes() (bool, error) {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -149,7 +149,7 @@ func (p *Preferences) GetDisableDNS() (bool, error) {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -167,7 +167,7 @@ func (p *Preferences) GetDisableFirewall() (bool, error) {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -185,7 +185,7 @@ func (p *Preferences) GetServerSSHAllowed() (bool, error) {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -207,7 +207,7 @@ func (p *Preferences) GetBlockInbound() (bool, error) {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
@@ -221,6 +221,6 @@ func (p *Preferences) SetBlockInbound(block bool) {
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
return err
}

View File

@@ -4,7 +4,7 @@ import (
"path/filepath"
"testing"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func TestPreferences_DefaultValues(t *testing.T) {
@@ -15,7 +15,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
t.Fatalf("failed to read default value: %s", err)
}
if defaultVar != internal.DefaultAdminURL {
if defaultVar != profilemanager.DefaultAdminURL {
t.Errorf("invalid default admin url: %s", defaultVar)
}
@@ -24,7 +24,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
t.Fatalf("failed to read default management URL: %s", err)
}
if defaultVar != internal.DefaultManagementURL {
if defaultVar != profilemanager.DefaultManagementURL {
t.Errorf("invalid default management url: %s", defaultVar)
}

View File

@@ -13,14 +13,23 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const errCloseConnection = "Failed to close connection: %v"
var (
logFileCount uint32
systemInfoFlag bool
uploadBundleFlag bool
uploadBundleURLFlag string
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -88,12 +97,13 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -105,7 +115,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -223,12 +233,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -255,7 +266,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -297,7 +308,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
)
}
return statusOutputString
@@ -345,7 +356,7 @@ func formatDuration(d time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
@@ -375,3 +386,15 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
config *profilemanager.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
@@ -28,7 +29,7 @@ const (
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
config *profilemanager.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
@@ -83,7 +84,7 @@ func SetupDebugHandler(
func waitForEvent(
ctx context.Context,
config *internal.Config,
config *profilemanager.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,

View File

@@ -20,7 +20,7 @@ var downCmd = &cobra.Command{
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
log.Errorf("failed initializing log %v", err)
return err

View File

@@ -4,10 +4,12 @@ import (
"context"
"fmt"
"os"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
@@ -15,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
@@ -22,19 +25,16 @@ import (
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
var loginCmd = &cobra.Command{
Use: "login",
Short: "login to the Netbird Management Service (first run)",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err)
}
ctx := internal.CtxInitState(context.Background())
@@ -43,6 +43,17 @@ var loginCmd = &cobra.Command{
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
pm := profilemanager.NewProfileManager()
activeProf, err := getActiveProfile(cmd.Context(), pm, profileName, username.Username)
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
providedSetupKey, err := getSetupKey()
if err != nil {
@@ -50,97 +61,15 @@ var loginCmd = &cobra.Command{
}
// workaround to run without service
if logFile == "console" {
err = handleRebrand(cmd)
if err != nil {
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
if util.FindFirstLogPath(logFiles) == "" {
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
cmd.Println("Logging successfully")
return nil
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
if err := doDaemonLogin(ctx, cmd, providedSetupKey, activeProf, username.Username, pm); err != nil {
return fmt.Errorf("daemon login failed: %v", err)
}
cmd.Println("Logging successfully")
@@ -149,7 +78,196 @@ var loginCmd = &cobra.Command{
},
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &activeProf.Name,
Username: &username,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
return fmt.Errorf("sso login failed: %v", err)
}
}
return nil
}
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
// switch profile if provided
if profileName != "" {
if err := switchProfileOnDaemon(ctx, pm, profileName, username); err != nil {
return nil, fmt.Errorf("switch profile: %v", err)
}
}
activeProf, err := pm.GetActiveProfile()
if err != nil {
return nil, fmt.Errorf("get active profile: %v", err)
}
if activeProf == nil {
return nil, fmt.Errorf("active profile not found, please run 'netbird profile create' first")
}
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
log.Errorf("failed to connect to service CLI interface %v", err)
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
log.Errorf("call service down method: %v", err)
return err
}
}
return nil
}
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &username,
})
if err != nil {
return fmt.Errorf("switch profile failed: %v", err)
}
return nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
err := handleRebrand(cmd)
if err != nil {
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
configFilePath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile file path: %v", err)
}
config, err := profilemanager.ReadConfig(configFilePath)
if err != nil {
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
cmd.Println("Logging successfully")
return nil
}
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
if resp.Email != "" {
err = pm.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email,
})
if err != nil {
log.Warnf("failed to set active profile email: %v", err)
}
}
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
needsLogin := false
err := WithBackOff(func() error {
@@ -195,7 +313,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil {
return nil, err
@@ -251,3 +369,16 @@ func isUnixRunningDesktop() bool {
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func setEnvAndFlags(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
return nil
}

View File

@@ -2,11 +2,11 @@ package cmd
import (
"fmt"
"os/user"
"strings"
"testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/util"
)
@@ -14,40 +14,41 @@ func TestLogin(t *testing.T) {
mgmAddr := startTestingServices(t)
tempDir := t.TempDir()
confPath := tempDir + "/config.json"
currUser, err := user.Current()
if err != nil {
t.Fatalf("failed to get current user: %v", err)
return
}
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
})
mgmtURL := fmt.Sprintf("http://%s", mgmAddr)
rootCmd.SetArgs([]string{
"login",
"--config",
confPath,
"--log-file",
"console",
util.LogConsole,
"--setup-key",
strings.ToUpper("a2c8e62b-38f5-4553-b31e-dd66c696cebb"),
"--management-url",
mgmtURL,
})
err := rootCmd.Execute()
if err != nil {
t.Fatal(err)
}
// validate generated config
actualConf := &internal.Config{}
_, err = util.ReadJson(confPath, actualConf)
if err != nil {
t.Errorf("expected proper config file written, got broken %v", err)
}
if actualConf.ManagementURL.String() != mgmtURL {
t.Errorf("expected management URL %s got %s", mgmtURL, actualConf.ManagementURL.String())
}
if actualConf.WgIface != iface.WgInterfaceDefault {
t.Errorf("expected WgIfaceName %s got %s", iface.WgInterfaceDefault, actualConf.WgIface)
}
if len(actualConf.PrivateKey) == 0 {
t.Errorf("expected non empty Private key, got empty")
}
// TODO(hakan): fix this test
_ = rootCmd.Execute()
}

57
client/cmd/logout.go Normal file
View File

@@ -0,0 +1,57 @@
package cmd
import (
"context"
"fmt"
"os/user"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
var logoutCmd = &cobra.Command{
Use: "logout",
Short: "logout from the Netbird Management Service and delete peer",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %v", err)
}
defer conn.Close()
daemonClient := proto.NewDaemonServiceClient(conn)
req := &proto.LogoutRequest{}
if profileName != "" {
req.ProfileName = &profileName
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
username := currUser.Username
req.Username = &username
}
if _, err := daemonClient.Logout(ctx, req); err != nil {
return fmt.Errorf("logout: %v", err)
}
cmd.Println("Logged out successfully")
return nil
},
}
func init() {
logoutCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
}

236
client/cmd/profile.go Normal file
View File

@@ -0,0 +1,236 @@
package cmd
import (
"context"
"fmt"
"os/user"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var profileCmd = &cobra.Command{
Use: "profile",
Short: "manage Netbird profiles",
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`,
}
var profileListCmd = &cobra.Command{
Use: "list",
Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`,
Aliases: []string{"ls"},
RunE: listProfilesFunc,
}
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "add a new profile",
Long: `Add a new profile to the Netbird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>",
Short: "remove a profile",
Long: `Remove a profile from the Netbird client. The profile must not be active.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>",
Short: "select a profile",
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return err
}
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
}
return nil
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
cmd.Println("Profile added successfully:", profileName)
return nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
cmd.Println("Profile removed successfully:", profileName)
return nil
}
func selectProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
profileManager := profilemanager.NewProfileManager()
profileName := args[0]
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("list profiles: %w", err)
}
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
return err
}
status, err := daemonClient.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("get service status: %w", err)
}
if status.Status == string(internal.StatusConnected) {
if _, err := daemonClient.Down(ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("call service down method: %w", err)
}
}
cmd.Println("Profile switched successfully to:", profileName)
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"os/signal"
"path"
"runtime"
"slices"
"strings"
"syscall"
"time"
@@ -21,8 +22,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
@@ -38,14 +38,10 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
configPath string
defaultConfigPathDir string
defaultConfigPath string
oldDefaultConfigPathDir string
@@ -55,7 +51,7 @@ var (
defaultLogFile string
oldDefaultLogFileDir string
oldDefaultLogFile string
logFile string
logFiles []string
daemonAddr string
managementURL string
adminURL string
@@ -71,15 +67,12 @@ var (
interfaceName string
wireguardPort uint16
networkMonitor bool
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool
profilesDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -123,38 +116,30 @@ func init() {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
}
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "(DEPRECATED) Netbird config file location")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
rootCmd.AddCommand(profileCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -167,6 +152,12 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -184,11 +175,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
@@ -196,14 +184,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
termCh := make(chan os.Signal, 1)
signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
go func() {
done := ctx.Done()
defer cancel()
select {
case <-done:
case <-ctx.Done():
case <-termCh:
}
log.Info("shutdown signal received")
cancel()
}()
}
@@ -287,7 +274,7 @@ func getSetupKeyFromFile(setupKeyPath string) (string, error) {
func handleRebrand(cmd *cobra.Command) error {
var err error
if logFile == defaultLogFile {
if slices.Contains(logFiles, defaultLogFile) {
if migrateToNetbird(oldDefaultLogFile, defaultLogFile) {
cmd.Printf("will copy Log dir %s and its content to %s\n", oldDefaultLogFileDir, defaultLogFileDir)
err = cpDir(oldDefaultLogFileDir, defaultLogFileDir)
@@ -296,15 +283,14 @@ func handleRebrand(cmd *cobra.Command) error {
}
}
}
if configPath == defaultConfigPath {
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
if err != nil {
return err
}
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,12 +1,15 @@
//go:build !ios && !android
package cmd
import (
"context"
"fmt"
"runtime"
"strings"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
@@ -14,6 +17,16 @@ import (
"github.com/netbirdio/netbird/client/server"
)
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
}
var (
serviceName string
serviceEnvVars []string
)
type program struct {
ctx context.Context
cancel context.CancelFunc
@@ -22,12 +35,32 @@ type program struct {
serverInstanceMu sync.Mutex
}
func init() {
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
rootCmd.AddCommand(serviceCmd)
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
ctx = internal.CtxInitState(ctx)
return &program{ctx: ctx, cancel: cancel}
}
func newSVCConfig() *service.Config {
func newSVCConfig() (*service.Config, error) {
config := &service.Config{
Name: serviceName,
DisplayName: "Netbird",
@@ -36,23 +69,47 @@ func newSVCConfig() *service.Config {
EnvVars: make(map[string]string),
}
if len(serviceEnvVars) > 0 {
extraEnvs, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
return nil, fmt.Errorf("parse service environment variables: %w", err)
}
config.EnvVars = extraEnvs
}
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
return config, nil
}
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
s, err := service.New(prg, conf)
if err != nil {
log.Fatal(err)
return nil, err
}
return s, nil
return service.New(prg, conf)
}
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
func parseServiceEnvVars(envVars []string) (map[string]string, error) {
envMap := make(map[string]string)
for _, env := range envVars {
if env == "" {
continue
}
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("empty environment variable key in: %s", env)
}
envMap[key] = value
}
return envMap, nil
}

View File

@@ -1,3 +1,5 @@
//go:build !ios && !android
package cmd
import (
@@ -47,20 +49,19 @@ func (p *program) Start(svc service.Service) error {
listen, err := net.Listen(split[0], split[1])
if err != nil {
return fmt.Errorf("failed to listen daemon interface: %w", err)
return fmt.Errorf("listen daemon interface: %w", err)
}
go func() {
defer listen.Close()
if split[0] == "unix" {
err = os.Chmod(split[1], 0666)
if err != nil {
if err := os.Chmod(split[1], 0666); err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
}
serverInstance := server.New(p.ctx, configPath, logFile)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), profilesDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}
@@ -100,37 +101,49 @@ func (p *program) Stop(srv service.Service) error {
return nil
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
if err := handleRebrand(cmd); err != nil {
return nil, err
}
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
cfg, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return nil, err
}
return s, nil
}
var runCmd = &cobra.Command{
Use: "run",
Short: "runs Netbird as service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Run()
if err != nil {
return err
}
return nil
return s.Run()
},
}
@@ -138,31 +151,14 @@ var startCmd = &cobra.Command{
Use: "start",
Short: "starts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
cmd.PrintErrln(err)
return err
}
err = s.Start()
if err != nil {
cmd.PrintErrln(err)
return err
if err := s.Start(); err != nil {
return fmt.Errorf("start service: %w", err)
}
cmd.Println("Netbird service has been started")
return nil
@@ -173,29 +169,14 @@ var stopCmd = &cobra.Command{
Use: "stop",
Short: "stops Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Stop()
if err != nil {
return err
if err := s.Stop(); err != nil {
return fmt.Errorf("stop service: %w", err)
}
cmd.Println("Netbird service has been stopped")
return nil
@@ -206,31 +187,48 @@ var restartCmd = &cobra.Command{
Use: "restart",
Short: "restarts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
err = s.Restart()
if err != nil {
return err
if err := s.Restart(); err != nil {
return fmt.Errorf("restart service: %w", err)
}
cmd.Println("Netbird service has been restarted")
return nil
},
}
var svcStatusCmd = &cobra.Command{
Use: "status",
Short: "shows Netbird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
status, err := s.Status()
if err != nil {
return fmt.Errorf("get service status: %w", err)
}
var statusText string
switch status {
case service.StatusRunning:
statusText = "Running"
case service.StatusStopped:
statusText = "Stopped"
case service.StatusUnknown:
statusText = "Unknown"
default:
statusText = fmt.Sprintf("Unknown (%d)", status)
}
cmd.Printf("Netbird service status: %s\n", statusText)
return nil
},
}

View File

@@ -1,87 +1,121 @@
//go:build !ios && !android
package cmd
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/kardianos/service"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/util"
)
var ErrGetServiceStatus = fmt.Errorf("failed to get service status")
// Common service command setup
func setupServiceCommand(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
return handleRebrand(cmd)
}
// Build service arguments for install/reconfigure
func buildServiceArguments() []string {
args := []string{
"service",
"run",
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
args = append(args, "--management-url", managementURL)
}
for _, logFile := range logFiles {
args = append(args, "--log-file", logFile)
}
return args
}
// Configure platform-specific service settings
func configurePlatformSpecificSettings(svcConfig *service.Config) error {
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile := util.FindFirstLogPath(logFiles); logFile != "" {
setStdLogPath := true
dir := filepath.Dir(logFile)
if _, err := os.Stat(dir); err != nil {
if err = os.MkdirAll(dir, 0750); err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
return nil
}
// Create fully configured service config for install/reconfigure
func createServiceConfigForInstall() (*service.Config, error) {
svcConfig, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
svcConfig.Arguments = buildServiceArguments()
if err = configurePlatformSpecificSettings(svcConfig); err != nil {
return nil, fmt.Errorf("configure platform-specific settings: %w", err)
}
return svcConfig, nil
}
var installCmd = &cobra.Command{
Use: "install",
Short: "installs Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
if err := setupServiceCommand(cmd); err != nil {
return err
}
svcConfig := newSVCConfig()
svcConfig.Arguments = []string{
"service",
"run",
"--config",
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
}
if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
}
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile != "console" {
setStdLogPath := true
dir := filepath.Dir(logFile)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0750)
if err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
cmd.PrintErrln(err)
return err
}
err = s.Install()
if err != nil {
cmd.PrintErrln(err)
return err
if err := s.Install(); err != nil {
return fmt.Errorf("install service: %w", err)
}
cmd.Println("Netbird service has been installed")
@@ -93,27 +127,109 @@ var uninstallCmd = &cobra.Command{
Use: "uninstall",
Short: "uninstalls Netbird service from system",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := setupServiceCommand(cmd); err != nil {
return err
}
cmd.SetOut(cmd.OutOrStdout())
cfg, err := newSVCConfig()
if err != nil {
return fmt.Errorf("create service config: %w", err)
}
err := handleRebrand(cmd)
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return err
}
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall service: %w", err)
}
cmd.Println("Netbird service has been uninstalled")
return nil
},
}
var reconfigureCmd = &cobra.Command{
Use: "reconfigure",
Short: "reconfigures Netbird service with new settings",
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil {
return err
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
return err
return fmt.Errorf("create service: %w", err)
}
err = s.Uninstall()
if err != nil {
return err
if wasRunning {
cmd.Println("Stopping Netbird service...")
if err := s.Stop(); err != nil {
cmd.Printf("Warning: failed to stop service: %v\n", err)
}
}
cmd.Println("Netbird service has been uninstalled")
cmd.Println("Removing existing service configuration...")
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall existing service: %w", err)
}
cmd.Println("Installing service with new configuration...")
if err := s.Install(); err != nil {
return fmt.Errorf("install service with new config: %w", err)
}
if wasRunning {
cmd.Println("Starting Netbird service...")
if err := s.Start(); err != nil {
return fmt.Errorf("start service after reconfigure: %w", err)
}
cmd.Println("Netbird service has been reconfigured and started")
} else {
cmd.Println("Netbird service has been reconfigured")
}
return nil
},
}
func isServiceRunning() (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return false, err
}
status, err := s.Status()
if err != nil {
return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err)
}
return status == service.StatusRunning, nil
}

263
client/cmd/service_test.go Normal file
View File

@@ -0,0 +1,263 @@
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {
name string
envVars []string
expected map[string]string
expectErr bool
}{
{
name: "Valid single env var",
envVars: []string{"LOG_LEVEL=debug"},
expected: map[string]string{
"LOG_LEVEL": "debug",
},
},
{
name: "Valid multiple env vars",
envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"},
expected: map[string]string{
"LOG_LEVEL": "debug",
"CUSTOM_VAR": "value",
},
},
{
name: "Env var with spaces",
envVars: []string{" KEY = value "},
expected: map[string]string{
"KEY": "value",
},
},
{
name: "Invalid format - no equals",
envVars: []string{"INVALID"},
expectErr: true,
},
{
name: "Invalid format - empty key",
envVars: []string{"=value"},
expectErr: true,
},
{
name: "Empty value is valid",
envVars: []string{"KEY="},
expected: map[string]string{
"KEY": "",
},
},
{
name: "Empty slice",
envVars: []string{},
expected: map[string]string{},
},
{
name: "Empty string in slice",
envVars: []string{"", "KEY=value", ""},
expected: map[string]string{"KEY": "value"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseServiceEnvVars(tt.envVars)
if tt.expectErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestServiceConfigWithEnvVars tests service config creation with env vars
func TestServiceConfigWithEnvVars(t *testing.T) {
originalServiceName := serviceName
originalServiceEnvVars := serviceEnvVars
defer func() {
serviceName = originalServiceName
serviceEnvVars = originalServiceEnvVars
}()
serviceName = "test-service"
serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"}
cfg, err := newSVCConfig()
require.NoError(t, err)
assert.Equal(t, "test-service", cfg.Name)
assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"])
assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"])
if runtime.GOOS == "linux" {
assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"])
}
}

View File

@@ -12,14 +12,15 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
port int
user = "root"
host string
port int
userName = "root"
host string
)
var sshCmd = &cobra.Command{
@@ -31,7 +32,7 @@ var sshCmd = &cobra.Command{
split := strings.Split(args[0], "@")
if len(split) == 2 {
user = split[0]
userName = split[0]
host = split[1]
} else {
host = args[0]
@@ -46,7 +47,7 @@ var sshCmd = &cobra.Command{
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -58,11 +59,19 @@ var sshCmd = &cobra.Command{
ctx := internal.CtxInitState(cmd.Context())
config, err := internal.UpdateConfig(internal.ConfigInput{
ConfigPath: configPath,
})
pm := profilemanager.NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
return err
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
}
sig := make(chan os.Signal, 1)
@@ -89,7 +98,7 @@ var sshCmd = &cobra.Command{
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +

View File

@@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -26,6 +27,7 @@ var (
statusFilter string
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
)
var statusCmd = &cobra.Command{
@@ -45,6 +47,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -57,7 +60,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err
}
err = util.InitLog(logLevel, "console")
err = util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -89,7 +92,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
@@ -120,7 +129,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -156,6 +165,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag()
}
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil
}

View File

@@ -103,13 +103,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}
@@ -124,7 +124,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
}
func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string,
t *testing.T, ctx context.Context, _, _ string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -134,7 +134,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
configPath, "")
"", false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"os/user"
"runtime"
"strings"
"time"
@@ -12,12 +13,14 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
@@ -35,6 +38,9 @@ const (
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
)
var (
@@ -42,6 +48,8 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
profileName string
configPath string
upCmd = &cobra.Command{
Use: "up",
@@ -70,6 +78,8 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
@@ -79,7 +89,7 @@ func upFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -101,13 +111,41 @@ func upFunc(cmd *cobra.Command, args []string) error {
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
if foregroundMode {
return runInForegroundMode(ctx, cmd)
pm := profilemanager.NewProfileManager()
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
return runInDaemonMode(ctx, cmd)
var profileSwitched bool
// switch profile if provided
if profileName != "" {
err = switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
activeProf, err := pm.GetActiveProfile()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
if foregroundMode {
return runInForegroundMode(ctx, cmd, activeProf)
}
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
err := handleRebrand(cmd)
if err != nil {
return err
@@ -118,7 +156,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
ic, err := setupConfig(customDNSAddressConverted, cmd)
configFilePath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile file path: %v", err)
}
ic, err := setupConfig(customDNSAddressConverted, cmd, configFilePath)
if err != nil {
return fmt.Errorf("setup config: %v", err)
}
@@ -128,12 +171,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
config, err := internal.UpdateOrCreateConfig(*ic)
config, err := profilemanager.UpdateOrCreateConfig(*ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
@@ -153,10 +196,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
return fmt.Errorf("parse custom DNS address: %v", err)
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -181,10 +224,41 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
if !profileSwitched {
cmd.Println("Already connected")
return nil
}
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
log.Errorf("call service down method: %v", err)
return err
}
}
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
}
if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil {
return fmt.Errorf("daemon up failed: %v", err)
}
cmd.Println("Connected")
return nil
}
func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, customDNSAddressConverted []byte, username string) error {
providedSetupKey, err := getSetupKey()
if err != nil {
return fmt.Errorf("get setup key: %v", err)
@@ -195,6 +269,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return fmt.Errorf("setup login request: %v", err)
}
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
var loginErr error
var loginResp *proto.LoginResponse
@@ -219,27 +296,105 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
return fmt.Errorf("sso login failed: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, profileName, username string) *proto.SetConfigRequest {
var req proto.SetConfigRequest
req.ProfileName = profileName
req.Username = username
req.ManagementUrl = managementURL
req.AdminURL = adminURL
req.NatExternalIPs = natExternalIPs
req.CustomDNSAddress = customDNSAddressConverted
req.ExtraIFaceBlacklist = extraIFaceBlackList
req.DnsLabels = dnsLabelsValidated.ToPunycodeList()
req.CleanDNSLabels = dnsLabels != nil && len(dnsLabels) == 0
req.CleanNATExternalIPs = natExternalIPs != nil && len(natExternalIPs) == 0
if cmd.Flag(enableRosenpassFlag).Changed {
req.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
req.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
return nil
}
req.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int64(wireguardPort)
req.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
req.OptionalPreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
req.DisableAutoConnect = &autoConnectDisabled
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
req.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
req.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
req.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
req.DisableDns = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
req.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
req.BlockLanAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
req.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
return &req
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFilePath string) (*profilemanager.ConfigInput, error) {
ic := profilemanager.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
ConfigPath: configFilePath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
@@ -325,7 +480,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
@@ -484,7 +638,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
if !isValidAddrPort(customDNSAddress) {
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
}
if customDNSAddress == "" && logFile != "console" {
if customDNSAddress == "" && util.FindFirstLogPath(logFiles) != "" {
parsed = []byte("empty")
} else {
parsed = []byte(customDNSAddress)

View File

@@ -3,18 +3,55 @@ package cmd
import (
"context"
"os"
"os/user"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
var cliAddr string
func TestUpDaemon(t *testing.T) {
mgmAddr := startTestingServices(t)
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.ConfigDirOverride = tempDir
currUser, err := user.Current()
if err != nil {
t.Fatalf("failed to get current user: %v", err)
return
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test1",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
return
}
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.ConfigDirOverride = ""
})
mgmAddr := startTestingServices(t)
confPath := tempDir + "/config.json"
ctx := internal.CtxInitState(context.Background())

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -26,7 +27,7 @@ var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
config *profilemanager.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
@@ -88,9 +89,9 @@ func New(opts Options) (*Client, error) {
}
t := true
var config *internal.Config
var config *profilemanager.Config
var err error
input := internal.ConfigInput{
input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
@@ -98,9 +99,9 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
config, err = profilemanager.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
config, err = profilemanager.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)

View File

@@ -221,7 +221,7 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -243,7 +243,7 @@ func (t *ICMPTracker) track(
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -294,7 +294,7 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -211,7 +211,7 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.logger.Trace2("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -240,7 +240,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
@@ -262,7 +262,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
@@ -340,17 +340,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
@@ -438,7 +438,7 @@ func (t *TCPTracker) cleanup() {
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change

View File

@@ -116,7 +116,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.logger.Trace2("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -104,6 +104,12 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages
@@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
// FilterOutBound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
// FilterInbound filters incoming packets
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.filterInbound(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -610,7 +601,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
return false
}
@@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false
}
@@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// dropFilter implements filtering logic for incoming packets.
// filterInbound implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, size int) bool {
func (m *Manager) filterInbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -736,19 +727,26 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
return true
}
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}
@@ -768,7 +766,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
@@ -809,7 +807,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
}
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
m.logger.Error1("Failed to inject local packet: %v", err)
}
// don't process this packet further
@@ -821,7 +819,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
@@ -837,7 +835,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
@@ -865,7 +863,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err)
m.logger.Error1("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
}
}
@@ -903,7 +901,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err)
m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false
}

View File

@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
// Test packet
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut, 0)
manager.filterOutbound(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0)
manager.filterInbound(testIn, 0)
}
})
}
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0)
manager.filterOutbound(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.filterInbound(inbound, 0)
}
})
}
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.filterOutbound(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
}
})
}
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
// Data transfer
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
// Connection teardown
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
}
})
}
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.filterOutbound(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.filterInbound(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.filterOutbound(ack, 0)
}
// Pre-generate test packets
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
manager.filterOutbound(outPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
}
})
})
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
}
})
})

View File

@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)

View File

@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), 0) {
if m.filterInbound(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0)
result := manager.filterOutbound(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
result = manager.filterOutbound(buf.Bytes(), 0)
require.False(t, result)
}
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0)
drop = manager.filterInbound(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -57,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
e.logger.Error1("CreateOutboundPacket: %v", err)
continue
}
written++

View File

@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
}()
@@ -52,11 +52,11 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
@@ -72,7 +72,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
@@ -80,7 +80,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to read ICMP response: %v", err)
}
return 0
}
@@ -101,12 +101,12 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)

View File

@@ -38,7 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
return
}
@@ -47,9 +47,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
f.logger.Error1("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
@@ -61,7 +61,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
@@ -75,10 +75,10 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
@@ -127,7 +127,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -78,10 +78,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(id), err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
conn.ep.Close()
@@ -112,10 +112,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
}
idle.conn.ep.Close()
@@ -124,7 +124,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
@@ -143,7 +143,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return
}
@@ -160,7 +160,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
}
@@ -169,9 +169,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
f.logger.Debug1("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
@@ -194,10 +194,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
@@ -205,7 +205,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
@@ -220,10 +220,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
pConn.cancel()
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
f.logger.Error2("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
f.logger.Error2("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64
@@ -263,7 +263,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -44,7 +44,12 @@ var levelStrings = map[Level]string{
type logMessage struct {
level Level
format string
args []any
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
}
// Logger is a high-performance, non-blocking logger
@@ -89,62 +94,198 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelError, format: format}:
default:
}
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
func (l *Logger) Warn(format string) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format}:
default:
}
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
func (l *Logger) Info(format string) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelInfo, format: format}:
default:
}
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
func (l *Logger) Debug(format string) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format}:
default:
}
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
func (l *Logger) Trace(format string) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
}
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
}
}
}
}
}
}
*buf = append(*buf, msg...)
var formatted string
switch argCount {
case 0:
formatted = msg.format
case 1:
formatted = fmt.Sprintf(msg.format, msg.arg1)
case 2:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2)
case 3:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3)
case 4:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4)
case 5:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
}
*buf = append(*buf, formatted...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
@@ -157,7 +298,7 @@ func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
l.formatMessage(bufp, msg)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
@@ -249,4 +390,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done:
return nil
}
}
}

View File

@@ -19,22 +19,17 @@ func (d *discard) Write(p []byte) (n int, err error) {
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
@@ -52,7 +47,7 @@ func BenchmarkLogger(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
})
@@ -62,7 +57,7 @@ func BenchmarkLogger(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
logger.Trace6("Complex trace: proto=%s dir=%s flags=%d seq=%d ack=%d size=%d", protocol, direction, flags, sequence, acknowledged, 1460)
}
})
}
@@ -72,7 +67,6 @@ func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
@@ -82,7 +76,7 @@ func BenchmarkLoggerParallel(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
})
}
@@ -92,7 +86,6 @@ func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
@@ -102,7 +95,7 @@ func BenchmarkLoggerBurst(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
}
}

View File

@@ -0,0 +1,408 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -0,0 +1,416 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -0,0 +1,145 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -0,0 +1,96 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers {
monoTime := record.LastActivity.Load()
activities[key] = monotime.Time(monoTime)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
var record *PeerRecord
record, exists := r.peers[publicKey]
if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else {
record = &PeerRecord{
Address: address,
}
record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record
}
r.addrToPeer[address] = record
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := int64(monotime.Now())
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -0,0 +1,25 @@
package bind
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/monotime"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,6 +1,7 @@
package bind
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -15,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type RecvMessage struct {
@@ -51,22 +53,24 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
activityRecorder *ActivityRecorder
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
}
rc := receiverCreator{
@@ -100,6 +104,10 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close()
}
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
@@ -146,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
@@ -199,6 +207,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
@@ -257,6 +270,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil
}
}
@@ -272,3 +292,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
}
msgsPool.Put(msgs)
}
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
var allAddresses []string
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
delete(m.addressMap, addr)
}
allAddresses = append(allAddresses, addresses...)
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
}
}
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
m.addressMapMu.RUnlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,22 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
return
}
// Userspace mode: UDPConn wrapper around nbnet.PacketConn
if wrapped, ok := m.params.UDPConn.(*UDPConn); ok {
if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
}
}
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type UDPConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr)
}
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) performFilterCheck(addr net.Addr) error {
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
@@ -276,3 +278,7 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -16,6 +16,8 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -36,16 +38,18 @@ const (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.startUAPI()
return wgCfg
@@ -87,7 +91,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -104,7 +120,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
@@ -205,6 +224,10 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr)
}
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error
@@ -507,7 +530,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
if currentPeer == nil {
continue
}
if val != "" {
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
}
}

View File

@@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -9,11 +9,11 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
// FilterOutbound filter outgoing packets from host to external destinations
FilterOutbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) {
if !filter.FilterInbound(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -99,8 +100,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if err != nil {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
UDPConn: udpConn,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,

View File

@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
)
type WGConfigurer interface {
@@ -19,4 +20,5 @@ type WGConfigurer interface {
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
const (
@@ -29,6 +30,11 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
var (
// ErrIfaceNotFound is returned when the WireGuard interface is not found
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -117,6 +123,9 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
@@ -126,6 +135,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey)
@@ -135,6 +147,9 @@ func (w *WGIface) RemovePeer(peerKey string) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
@@ -144,6 +159,9 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
@@ -214,10 +232,29 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
// GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.GetStats()
}
func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return nil
}
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.FullStats()
}

View File

@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.

View File

@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
var skipProxy bool
if val := os.Getenv(EnvSkipProxy); val != "" {
skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
}
if skipProxy {
return nsTunDev, tunNet, nil

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
}
// AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
}
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
@@ -151,7 +171,7 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
return nil, fmt.Errorf("parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))

View File

@@ -11,6 +11,8 @@ import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
@@ -77,8 +92,10 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
return fmt.Errorf("close remote conn: %w", err)
}
return nil
}
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil {
return 0, ctx.Err()
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
return proxyBind.NewProxyBind(w.bind)
}
func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,32 @@
package listener
import "sync"
type CloseListener struct {
listener func()
mu sync.Mutex
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.mu.Lock()
defer c.mu.Unlock()
c.listener = listener
}
func (c *CloseListener) Notify() {
c.mu.Lock()
if c.listener == nil {
c.mu.Unlock()
return
}
listener := c.listener
c.mu.Unlock()
listener()
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -17,7 +17,7 @@ import (
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
_ = util.InitLog("trace", util.LogConsole)
code := m.Run()
os.Exit(code)
}
@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{
localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
}
return p
}
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr
}
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() {
if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
@@ -172,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
for {
n, err := p.remoteConnRead(ctx, buf)
if err != nil {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
return
}

View File

@@ -11,6 +11,7 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
@@ -48,6 +49,7 @@ type TokenInfo struct {
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
Email string `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
@@ -64,7 +66,7 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
@@ -80,7 +82,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopCli
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
@@ -89,7 +91,7 @@ func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAu
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {

View File

@@ -6,6 +6,7 @@ import (
"crypto/subtle"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"html/template"
@@ -230,9 +231,46 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
if err != nil {
log.Warnf("failed to parse email from ID token: %v", err)
} else {
tokenInfo.Email = email
}
return tokenInfo, nil
}
func parseEmailFromIDToken(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return "", fmt.Errorf("invalid token format")
}
data, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("failed to decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(data, &claims); err != nil {
return "", fmt.Errorf("json unmarshal error: %w", err)
}
var email string
if emailValue, ok := claims["email"].(string); ok {
email = emailValue
} else {
val, ok := claims["name"].(string)
if ok {
email = val
} else {
return "", fmt.Errorf("email or name field not found in token payload")
}
}
return email, nil
}
func createCodeChallenge(codeVerifier string) string {
sha2 := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(sha2[:])

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
@@ -26,11 +25,11 @@ import (
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
rosenpassEnabled bool
lazyConnMgr *manager.Manager
@@ -39,12 +38,12 @@ type ConnMgr struct {
lazyCtxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
rosenpassEnabled: engineConfig.RosenpassEnabled,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
@@ -64,6 +63,11 @@ func (e *ConnMgr) Start(ctx context.Context) {
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
@@ -83,7 +87,12 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
return nil
}
log.Warnf("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
@@ -133,7 +142,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
@@ -175,7 +184,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
@@ -201,7 +210,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
if !ok {
return
}
defer conn.Close()
defer conn.Close(false)
if !e.isStartedWithLazyMgr() {
return
@@ -211,23 +220,27 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return conn, true
return
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
// If locally the lazy connection is disabled, we force the peer connection open.
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
}
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
}
func (e *ConnMgr) Close() {
@@ -244,7 +257,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
@@ -275,7 +288,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {

View File

@@ -17,11 +17,11 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
@@ -38,7 +38,7 @@ import (
type ConnectClient struct {
ctx context.Context
config *Config
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
@@ -48,7 +48,7 @@ type ConnectClient struct {
func NewConnectClient(
ctx context.Context,
config *Config,
config *profilemanager.Config,
statusRecorder *peer.Status,
) *ConnectClient {
@@ -414,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -484,7 +484,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
}
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
@@ -526,17 +526,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
addr := net.UDPAddr{Port: initPort}
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
closeConnWithLog(conn)
return initPort, nil
return returnPort, nil
}
// if the port is already in use, ask the system for a free port

View File

@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
shouldMatch bool
}{
{
name: "not provided, fallback to default",
name: "when port is 0 use random port",
port: 0,
want: 51820,
shouldMatch: true,
want: 0,
shouldMatch: false,
},
{
name: "provided and available",
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil {
t.Errorf("freePort error = %v", err)
}
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
_ = c1.Close()
}(c1)
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
tests[1].port++
tests[1].want++
}
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -24,10 +25,10 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/profilemanager"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -38,10 +39,12 @@ status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized client log file of the NetBird client.
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
@@ -105,7 +108,29 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
The routes.txt file contains detailed routing table information in a tabular format:
- Destination: Network prefix (IP_ADDRESS/PREFIX_LENGTH)
- Gateway: Next hop IP address (or "-" if direct)
- Interface: Network interface name
- Metric: Route priority/metric (lower values preferred)
- Protocol: Routing protocol (kernel, static, dhcp, etc.)
- Scope: Route scope (global, link, host, etc.)
- Type: Route type (unicast, local, broadcast, etc.)
- Table: Routing table name (main, local, netbird, etc.)
The table format provides a comprehensive view of the system's routing configuration, including information from multiple routing tables on Linux systems. This is valuable for troubleshooting routing issues and understanding traffic flow.
For anonymized routes, IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. Interface names are anonymized using string anonymization.
Resolved Domains
The resolved_domains.txt file contains information about domain names that have been resolved to IP addresses by NetBird's DNS resolver. This includes:
- Original domain patterns that were configured for routing
- Resolved domain names that matched those patterns
- IP address prefixes that were resolved for each domain
- Parent domain associations showing which original pattern each resolved domain belongs to
All domain names and IP addresses in this file follow the same anonymization rules as described above. This information is valuable for troubleshooting DNS resolution and routing issues.
Network Interfaces
The interfaces.txt file contains information about network interfaces, including:
@@ -143,6 +168,22 @@ nftables.txt:
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
- Priority: Rule priority number (lower values processed first)
- From: Source IP prefix or "all" if unspecified
- To: Destination IP prefix or "all" if unspecified
- IIF: Input interface name or "-" if unspecified
- OIF: Output interface name or "-" if unspecified
- Table: Target routing table name (main, local, netbird, etc.)
- Action: Rule action (lookup, goto, blackhole, etc.)
- Mark: Firewall mark value in hex format or "-" if unspecified
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
`
const (
@@ -158,15 +199,15 @@ type BundleGenerator struct {
anonymizer *anonymize.Anonymizer
// deps
internalConfig *internal.Config
internalConfig *profilemanager.Config
statusRecorder *peer.Status
networkMap *mgmProto.NetworkMap
logFile string
// config
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
archive *zip.Writer
}
@@ -175,16 +216,23 @@ type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
type GeneratorDependencies struct {
InternalConfig *internal.Config
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
NetworkMap *mgmProto.NetworkMap
LogFile string
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
// Default to 1 log file for backward compatibility when 0 is provided
logFileCount := cfg.LogFileCount
if logFileCount == 0 {
logFileCount = 1
}
return &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
@@ -196,6 +244,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
}
@@ -247,7 +296,11 @@ func (g *BundleGenerator) createArchive() error {
}
if err := g.addConfig(); err != nil {
log.Errorf("Failed to add config to debug bundle: %v", err)
log.Errorf("failed to add config to debug bundle: %v", err)
}
if err := g.addResolvedDomains(); err != nil {
log.Errorf("failed to add resolved domains to debug bundle: %v", err)
}
if g.includeSystemInfo {
@@ -255,7 +308,7 @@ func (g *BundleGenerator) createArchive() error {
}
if err := g.addProf(); err != nil {
log.Errorf("Failed to add profiles to debug bundle: %v", err)
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addNetworkMap(); err != nil {
@@ -263,26 +316,26 @@ func (g *BundleGenerator) createArchive() error {
}
if err := g.addStateFile(); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
log.Errorf("failed to add state file to debug bundle: %v", err)
}
if err := g.addCorruptedStateFiles(); err != nil {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addWgShow(); err != nil {
log.Errorf("Failed to add wg show output: %v", err)
log.Errorf("failed to add wg show output: %v", err)
}
if g.logFile != "console" && g.logFile != "" {
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if err := g.addLogfile(); err != nil {
log.Errorf("Failed to add log file to debug bundle: %v", err)
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
log.Errorf("failed to add systemd logs: %v", err)
}
return nil
@@ -290,15 +343,19 @@ func (g *BundleGenerator) createArchive() error {
func (g *BundleGenerator) addSystemInfo() {
if err := g.addRoutes(); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
log.Errorf("failed to add routes to debug bundle: %v", err)
}
if err := g.addInterfaces(); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
log.Errorf("failed to add interfaces to debug bundle: %v", err)
}
if err := g.addIPRules(); err != nil {
log.Errorf("failed to add IP rules to debug bundle: %v", err)
}
if err := g.addFirewallRules(); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
}
@@ -353,7 +410,6 @@ func (g *BundleGenerator) addConfig() error {
}
}
// Add config content to zip file
configReader := strings.NewReader(configContent.String())
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
return fmt.Errorf("add config file to zip: %w", err)
@@ -365,7 +421,6 @@ func (g *BundleGenerator) addConfig() error {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
// Add non-sensitive fields
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil {
@@ -450,6 +505,27 @@ func (g *BundleGenerator) addInterfaces() error {
return nil
}
func (g *BundleGenerator) addResolvedDomains() error {
if g.statusRecorder == nil {
log.Debugf("skipping resolved domains in debug bundle: no status recorder")
return nil
}
resolvedDomains := g.statusRecorder.GetResolvedDomainsStates()
if len(resolvedDomains) == 0 {
log.Debugf("skipping resolved domains in debug bundle: no resolved domains")
return nil
}
resolvedDomainsContent := formatResolvedDomains(resolvedDomains, g.anonymize, g.anonymizer)
resolvedDomainsReader := strings.NewReader(resolvedDomainsContent)
if err := g.addFileToZip(resolvedDomainsReader, "resolved_domains.txt"); err != nil {
return fmt.Errorf("add resolved domains file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addNetworkMap() error {
if g.networkMap == nil {
log.Debugf("skipping empty network map in debug bundle")
@@ -482,7 +558,8 @@ func (g *BundleGenerator) addNetworkMap() error {
}
func (g *BundleGenerator) addStateFile() error {
path := statemanager.GetDefaultStatePath()
sm := profilemanager.ServiceManager{}
path := sm.GetStatePath()
if path == "" {
return nil
}
@@ -520,7 +597,8 @@ func (g *BundleGenerator) addStateFile() error {
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
pattern := statemanager.GetDefaultStatePath()
sm := profilemanager.ServiceManager{}
pattern := sm.GetStatePath()
if pattern == "" {
return nil
}
@@ -561,32 +639,7 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err)
}
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
g.addRotatedLogFiles(logDir)
stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
@@ -614,7 +667,7 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("Failed to close log file %s: %v", targetName, err)
log.Errorf("failed to close log file %s: %v", targetName, err)
}
}()
@@ -638,13 +691,21 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
if err != nil {
return fmt.Errorf("open gz log file %s: %w", targetName, err)
}
defer f.Close()
defer func() {
if err := f.Close(); err != nil {
log.Errorf("failed to close gz file %s: %v", targetName, err)
}
}()
gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("create gzip reader: %w", err)
}
defer gzr.Close()
defer func() {
if err := gzr.Close(); err != nil {
log.Errorf("failed to close gzip reader %s: %v", targetName, err)
}
}()
var logReader io.Reader = gzr
if g.anonymize {
@@ -670,6 +731,51 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
return nil
}
// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount
func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
if g.logFileCount == 0 {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
return
}
if len(files) == 0 {
return
}
// sort files by modification time (newest first)
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().After(fj.ModTime())
})
maxFiles := int(g.logFileCount)
if maxFiles > len(files) {
maxFiles = len(files)
}
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
@@ -684,7 +790,7 @@ func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error
// If the reader is a file, we can get more accurate information
if f, ok := reader.(*os.File); ok {
if stat, err := f.Stat(); err != nil {
log.Tracef("Failed to get file stat for %s: %v", filename, err)
log.Tracef("failed to get file stat for %s: %v", filename, err)
} else {
header.Modified = stat.ModTime()
}
@@ -732,89 +838,6 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
}
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
var ipv4Routes, ipv6Routes []netip.Prefix
// Separate IPv4 and IPv6 routes
for _, route := range routes {
if route.Addr().Is4() {
ipv4Routes = append(ipv4Routes, route)
} else {
ipv6Routes = append(ipv6Routes, route)
}
}
// Sort IPv4 and IPv6 routes separately
sort.Slice(ipv4Routes, func(i, j int) bool {
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
})
sort.Slice(ipv6Routes, func(i, j int) bool {
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
})
var builder strings.Builder
// Format IPv4 routes
builder.WriteString("IPv4 Routes:\n")
for _, route := range ipv4Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
// Format IPv6 routes
builder.WriteString("\nIPv6 Routes:\n")
for _, route := range ipv6Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
return builder.String()
}
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
} else {
builder.WriteString(fmt.Sprintf("%s\n", route))
}
}
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Name < interfaces[j].Name
})
var builder strings.Builder
builder.WriteString("Network Interfaces:\n")
for _, iface := range interfaces {
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
addrs, err := iface.Addrs()
if err != nil {
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
} else {
builder.WriteString(" Addresses:\n")
for _, addr := range addrs {
prefix, err := netip.ParsePrefix(addr.String())
if err != nil {
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
continue
}
ip := prefix.Addr()
if anonymize {
ip = anonymizer.AnonymizeIP(ip)
}
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
}
}
}
return builder.String()
}
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
@@ -921,7 +944,6 @@ func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.
}
for i, ip := range peer.AllowedIps {
// Try to parse as prefix first (CIDR)
if prefix, err := netip.ParsePrefix(ip); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
@@ -1000,7 +1022,7 @@ func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.An
func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
switch record.Type {
case 1, 28: // A or AAAA record
case 1, 28:
if addr, err := netip.ParseAddr(record.RData); err == nil {
record.RData = anonymizer.AnonymizeIP(addr).String()
}

View File

@@ -17,8 +17,27 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// addIPRules collects and adds IP rules to the archive
func (g *BundleGenerator) addIPRules() error {
log.Info("Collecting IP rules")
ipRules, err := systemops.GetIPRules()
if err != nil {
return fmt.Errorf("get IP rules: %w", err)
}
rulesContent := formatIPRulesTable(ipRules, g.anonymize, g.anonymizer)
rulesReader := strings.NewReader(rulesContent)
if err := g.addFileToZip(rulesReader, "ip_rules.txt"); err != nil {
return fmt.Errorf("add IP rules file to zip: %w", err)
}
return nil
}
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
@@ -136,7 +155,6 @@ func (g *BundleGenerator) addFirewallRules() error {
func collectIPTablesRules() (string, error) {
var builder strings.Builder
// First try using iptables-save
saveOutput, err := collectIPTablesSave()
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
@@ -146,7 +164,6 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n")
}
// Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
@@ -232,11 +249,9 @@ func getTableStatistics(table string) (string, error) {
// collectNFTablesRules attempts to collect nftables rules using either nft command or netlink
func collectNFTablesRules() (string, error) {
// First try using nft command
rules, err := collectNFTablesFromCommand()
if err != nil {
log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err)
// Fall back to netlink
rules, err = collectNFTablesFromNetlink()
if err != nil {
return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err)
@@ -451,7 +466,6 @@ func formatRule(rule *nftables.Rule) string {
func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int {
curr := exprs[i]
// Handle Meta + Cmp sequence
if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) {
if cmp, ok := exprs[i+1].(*expr.Cmp); ok {
if formatted := formatMetaWithCmp(meta, cmp); formatted != "" {
@@ -461,7 +475,6 @@ func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int {
}
}
// Handle Payload + Cmp sequence
if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) {
if cmp, ok := exprs[i+1].(*expr.Cmp); ok {
builder.WriteString(formatPayloadWithCmp(payload, cmp))
@@ -493,13 +506,13 @@ func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string {
func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
if p.Base == expr.PayloadBaseNetworkHeader {
switch p.Offset {
case 12: // Source IP
case 12:
if p.Len == 4 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
case 16: // Destination IP
case 16:
if p.Len == 4 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
@@ -580,7 +593,6 @@ func formatExpr(exp expr.Any) string {
}
func formatImmediateData(data []byte) string {
// For IP addresses (4 bytes)
if len(data) == 4 {
return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3])
}
@@ -588,26 +600,21 @@ func formatImmediateData(data []byte) string {
}
func formatMeta(e *expr.Meta) string {
// Handle source register case first (meta mark set)
if e.SourceRegister {
return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register)
}
// For interface names, handle register load operation
switch e.Key {
case expr.MetaKeyIIFNAME,
expr.MetaKeyOIFNAME,
expr.MetaKeyBRIIIFNAME,
expr.MetaKeyBRIOIFNAME:
// Simply the key name with no register reference
return formatMetaKey(e.Key)
case expr.MetaKeyMARK:
// For mark operations, we want just "mark"
return "mark"
}
// For other meta keys, show as loading into register
return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register)
}

View File

@@ -12,3 +12,8 @@ func (g *BundleGenerator) trySystemdLogFallback() error {
// TODO: Add BSD support
return nil
}
func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}

View File

@@ -10,16 +10,16 @@ import (
)
func (g *BundleGenerator) addRoutes() error {
routes, err := systemops.GetRoutesFromTable()
detailedRoutes, err := systemops.GetDetailedRoutesFromTable()
if err != nil {
return fmt.Errorf("get routes: %w", err)
return fmt.Errorf("get detailed routes: %w", err)
}
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
routesContent := formatRoutesTable(detailedRoutes, g.anonymize, g.anonymizer)
routesReader := strings.NewReader(routesContent)
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
return nil
}

View File

@@ -0,0 +1,206 @@
package debug
import (
"fmt"
"net"
"net/netip"
"sort"
"strings"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/management/domain"
)
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Name < interfaces[j].Name
})
var builder strings.Builder
builder.WriteString("Network Interfaces:\n")
for _, iface := range interfaces {
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
addrs, err := iface.Addrs()
if err != nil {
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
} else {
builder.WriteString(" Addresses:\n")
for _, addr := range addrs {
prefix, err := netip.ParsePrefix(addr.String())
if err != nil {
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
continue
}
ip := prefix.Addr()
if anonymize {
ip = anonymizer.AnonymizeIP(ip)
}
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
}
}
}
return builder.String()
}
func formatResolvedDomains(resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if len(resolvedDomains) == 0 {
return "No resolved domains found.\n"
}
var builder strings.Builder
builder.WriteString("Resolved Domains:\n")
builder.WriteString("=================\n\n")
var sortedParents []domain.Domain
for parentDomain := range resolvedDomains {
sortedParents = append(sortedParents, parentDomain)
}
sort.Slice(sortedParents, func(i, j int) bool {
return sortedParents[i].SafeString() < sortedParents[j].SafeString()
})
for _, parentDomain := range sortedParents {
info := resolvedDomains[parentDomain]
parentKey := parentDomain.SafeString()
if anonymize {
parentKey = anonymizer.AnonymizeDomain(parentKey)
}
builder.WriteString(fmt.Sprintf("%s:\n", parentKey))
var sortedIPs []string
for _, prefix := range info.Prefixes {
ipStr := prefix.String()
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(prefix.Addr())
ipStr = fmt.Sprintf("%s/%d", anonymizedIP, prefix.Bits())
}
sortedIPs = append(sortedIPs, ipStr)
}
sort.Strings(sortedIPs)
for _, ipStr := range sortedIPs {
builder.WriteString(fmt.Sprintf(" %s\n", ipStr))
}
builder.WriteString("\n")
}
return builder.String()
}
func formatRoutesTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if len(detailedRoutes) == 0 {
return "No routes found.\n"
}
sort.Slice(detailedRoutes, func(i, j int) bool {
if detailedRoutes[i].Table != detailedRoutes[j].Table {
return detailedRoutes[i].Table < detailedRoutes[j].Table
}
return detailedRoutes[i].Route.Dst.String() < detailedRoutes[j].Route.Dst.String()
})
headers, rows := buildPlatformSpecificRouteTable(detailedRoutes, anonymize, anonymizer)
return formatTable("Routing Table:", headers, rows)
}
func formatRouteDestination(destination netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if anonymize {
anonymizedDestIP := anonymizer.AnonymizeIP(destination.Addr())
return fmt.Sprintf("%s/%d", anonymizedDestIP, destination.Bits())
}
return destination.String()
}
func formatRouteGateway(gateway netip.Addr, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if gateway.IsValid() {
if anonymize {
return anonymizer.AnonymizeIP(gateway).String()
}
return gateway.String()
}
return "-"
}
func formatRouteInterface(iface *net.Interface) string {
if iface != nil {
return iface.Name
}
return "-"
}
func formatInterfaceIndex(index int) string {
if index <= 0 {
return "-"
}
return fmt.Sprintf("%d", index)
}
func formatRouteMetric(metric int) string {
if metric < 0 {
return "-"
}
return fmt.Sprintf("%d", metric)
}
func formatTable(title string, headers []string, rows [][]string) string {
widths := make([]int, len(headers))
for i, header := range headers {
widths[i] = len(header)
}
for _, row := range rows {
for i, cell := range row {
if len(cell) > widths[i] {
widths[i] = len(cell)
}
}
}
for i := range widths {
widths[i] += 2
}
var formatParts []string
for _, width := range widths {
formatParts = append(formatParts, fmt.Sprintf("%%-%ds", width))
}
formatStr := strings.Join(formatParts, "") + "\n"
var builder strings.Builder
builder.WriteString(title + "\n")
builder.WriteString(strings.Repeat("=", len(title)) + "\n\n")
headerArgs := make([]interface{}, len(headers))
for i, header := range headers {
headerArgs[i] = header
}
builder.WriteString(fmt.Sprintf(formatStr, headerArgs...))
separatorArgs := make([]interface{}, len(headers))
for i, width := range widths {
separatorArgs[i] = strings.Repeat("-", width-2)
}
builder.WriteString(fmt.Sprintf(formatStr, separatorArgs...))
for _, row := range rows {
rowArgs := make([]interface{}, len(row))
for i, cell := range row {
rowArgs[i] = cell
}
builder.WriteString(fmt.Sprintf(formatStr, rowArgs...))
}
return builder.String()
}

View File

@@ -0,0 +1,185 @@
//go:build linux && !android
package debug
import (
"fmt"
"net/netip"
"sort"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func formatIPRulesTable(ipRules []systemops.IPRule, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if len(ipRules) == 0 {
return "No IP rules found.\n"
}
sort.Slice(ipRules, func(i, j int) bool {
return ipRules[i].Priority < ipRules[j].Priority
})
columnConfig := detectIPRuleColumns(ipRules)
headers := buildIPRuleHeaders(columnConfig)
rows := buildIPRuleRows(ipRules, columnConfig, anonymize, anonymizer)
return formatTable("IP Rules:", headers, rows)
}
type ipRuleColumnConfig struct {
hasInvert, hasTo, hasMark, hasIIF, hasOIF, hasSuppressPlen bool
}
func detectIPRuleColumns(ipRules []systemops.IPRule) ipRuleColumnConfig {
var config ipRuleColumnConfig
for _, rule := range ipRules {
if rule.Invert {
config.hasInvert = true
}
if rule.To.IsValid() {
config.hasTo = true
}
if rule.Mark != 0 {
config.hasMark = true
}
if rule.IIF != "" {
config.hasIIF = true
}
if rule.OIF != "" {
config.hasOIF = true
}
if rule.SuppressPlen >= 0 {
config.hasSuppressPlen = true
}
}
return config
}
func buildIPRuleHeaders(config ipRuleColumnConfig) []string {
var headers []string
headers = append(headers, "Priority")
if config.hasInvert {
headers = append(headers, "Not")
}
headers = append(headers, "From")
if config.hasTo {
headers = append(headers, "To")
}
if config.hasMark {
headers = append(headers, "FWMark")
}
if config.hasIIF {
headers = append(headers, "IIF")
}
if config.hasOIF {
headers = append(headers, "OIF")
}
headers = append(headers, "Table")
headers = append(headers, "Action")
if config.hasSuppressPlen {
headers = append(headers, "SuppressPlen")
}
return headers
}
func buildIPRuleRows(ipRules []systemops.IPRule, config ipRuleColumnConfig, anonymize bool, anonymizer *anonymize.Anonymizer) [][]string {
var rows [][]string
for _, rule := range ipRules {
row := buildSingleIPRuleRow(rule, config, anonymize, anonymizer)
rows = append(rows, row)
}
return rows
}
func buildSingleIPRuleRow(rule systemops.IPRule, config ipRuleColumnConfig, anonymize bool, anonymizer *anonymize.Anonymizer) []string {
var row []string
row = append(row, fmt.Sprintf("%d", rule.Priority))
if config.hasInvert {
row = append(row, formatIPRuleInvert(rule.Invert))
}
row = append(row, formatIPRuleAddress(rule.From, "all", anonymize, anonymizer))
if config.hasTo {
row = append(row, formatIPRuleAddress(rule.To, "-", anonymize, anonymizer))
}
if config.hasMark {
row = append(row, formatIPRuleMark(rule.Mark, rule.Mask))
}
if config.hasIIF {
row = append(row, formatIPRuleInterface(rule.IIF))
}
if config.hasOIF {
row = append(row, formatIPRuleInterface(rule.OIF))
}
row = append(row, rule.Table)
row = append(row, formatIPRuleAction(rule.Action))
if config.hasSuppressPlen {
row = append(row, formatIPRuleSuppressPlen(rule.SuppressPlen))
}
return row
}
func formatIPRuleInvert(invert bool) string {
if invert {
return "not"
}
return "-"
}
func formatIPRuleAction(action string) string {
if action == "unspec" {
return "lookup"
}
return action
}
func formatIPRuleSuppressPlen(suppressPlen int) string {
if suppressPlen >= 0 {
return fmt.Sprintf("%d", suppressPlen)
}
return "-"
}
func formatIPRuleAddress(prefix netip.Prefix, defaultVal string, anonymize bool, anonymizer *anonymize.Anonymizer) string {
if !prefix.IsValid() {
return defaultVal
}
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonymizedIP, prefix.Bits())
}
return prefix.String()
}
func formatIPRuleMark(mark, mask uint32) string {
if mark == 0 {
return "-"
}
if mask != 0 {
return fmt.Sprintf("0x%x/0x%x", mark, mask)
}
return fmt.Sprintf("0x%x", mark)
}
func formatIPRuleInterface(iface string) string {
if iface == "" {
return "-"
}
return iface
}

View File

@@ -0,0 +1,27 @@
//go:build !windows
package debug
import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// buildPlatformSpecificRouteTable builds headers and rows for non-Windows platforms
func buildPlatformSpecificRouteTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) ([]string, [][]string) {
headers := []string{"Destination", "Gateway", "Interface", "Idx", "Metric", "Protocol", "Scope", "Type", "Table", "Flags"}
var rows [][]string
for _, route := range detailedRoutes {
destStr := formatRouteDestination(route.Route.Dst, anonymize, anonymizer)
gatewayStr := formatRouteGateway(route.Route.Gw, anonymize, anonymizer)
interfaceStr := formatRouteInterface(route.Route.Interface)
indexStr := formatInterfaceIndex(route.InterfaceIndex)
metricStr := formatRouteMetric(route.Metric)
row := []string{destStr, gatewayStr, interfaceStr, indexStr, metricStr, route.Protocol, route.Scope, route.Type, route.Table, route.Flags}
rows = append(rows, row)
}
return headers, rows
}

View File

@@ -0,0 +1,37 @@
//go:build windows
package debug
import (
"fmt"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// buildPlatformSpecificRouteTable builds headers and rows for Windows with interface metrics
func buildPlatformSpecificRouteTable(detailedRoutes []systemops.DetailedRoute, anonymize bool, anonymizer *anonymize.Anonymizer) ([]string, [][]string) {
headers := []string{"Destination", "Gateway", "Interface", "Idx", "Metric", "If Metric", "Protocol", "Age", "Origin"}
var rows [][]string
for _, route := range detailedRoutes {
destStr := formatRouteDestination(route.Route.Dst, anonymize, anonymizer)
gatewayStr := formatRouteGateway(route.Route.Gw, anonymize, anonymizer)
interfaceStr := formatRouteInterface(route.Route.Interface)
indexStr := formatInterfaceIndex(route.InterfaceIndex)
metricStr := formatRouteMetric(route.Metric)
ifMetricStr := formatInterfaceMetric(route.InterfaceMetric)
row := []string{destStr, gatewayStr, interfaceStr, indexStr, metricStr, ifMetricStr, route.Protocol, route.Scope, route.Type}
rows = append(rows, row)
}
return headers, rows
}
func formatInterfaceMetric(metric int) string {
if metric < 0 {
return "-"
}
return fmt.Sprintf("%d", metric)
}

View File

@@ -4,8 +4,8 @@ package dns
import (
"fmt"
"net/netip"
"os"
"regexp"
"strings"
log "github.com/sirupsen/logrus"
@@ -15,9 +15,6 @@ const (
defaultResolvConfPath = "/etc/resolv.conf"
)
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
type resolvConf struct {
nameServers []string
searchDomains []string
@@ -108,40 +105,9 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
return rconf, nil
}
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
// otherwise it adds a new option with timeout and attempts.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
configs := make([]string, len(input))
copy(configs, input)
for i, config := range configs {
if strings.HasPrefix(config, "options") {
config = strings.ReplaceAll(config, "rotate", "")
config = strings.Join(strings.Fields(config), " ")
if strings.Contains(config, "timeout:") {
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
}
if strings.Contains(config, "attempts:") {
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
}
configs[i] = config
return configs
}
}
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
}
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename, nameserverIP string) error {
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
@@ -151,7 +117,7 @@ func removeFirstNbNameserver(filename, nameserverIP string) error {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)

Some files were not shown because too many files have changed in this diff Show More