Compare commits

..

116 Commits

Author SHA1 Message Date
Zoltán Papp
619f1588b3 Add ldflags for wasm 2025-11-19 11:48:41 +01:00
Zoltán Papp
60d2a2c7df Fix error handling 2025-11-19 11:02:44 +01:00
Zoltán Papp
ff9585735b Fix permission checks 2025-11-19 10:54:50 +01:00
Zoltán Papp
194951c88d Typo fix 2025-11-19 10:43:41 +01:00
Zoltán Papp
d63f2e5196 Typo fix 2025-11-18 19:16:14 +01:00
Zoltán Papp
14b5637555 Fix error message 2025-11-18 19:15:18 +01:00
Zoltán Papp
3d3b05c157 Disable lint for deprecated line 2025-11-18 19:11:40 +01:00
Zoltán Papp
f36e206238 Fix account moc 2025-11-18 18:54:08 +01:00
Zoltán Papp
9a808244d7 Fix lint issue 2025-11-18 18:46:53 +01:00
Zoltán Papp
2dec76f8ea Fix test after merge conflict 2025-11-18 18:39:14 +01:00
Zoltán Papp
224bd8ff22 Merge branch 'main' into feature/remote-debug-clean
# Conflicts:
#	client/cmd/testutil_test.go
#	client/internal/engine_test.go
#	client/server/server.go
#	client/server/server_test.go
#	client/status/status.go
#	go.mod
#	go.sum
#	management/internals/server/boot.go
#	management/internals/server/modules.go
#	management/internals/shared/grpc/server.go
#	management/server/account.go
#	management/server/account/manager.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/testing/testing_tools/channel/channel.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer_test.go
#	management/server/route_test.go
#	shared/management/client/client_test.go
#	shared/management/proto/management.pb.go
2025-11-18 18:30:48 +01:00
Zoltán Papp
fe88a5662e Fix log message 2025-11-18 17:33:54 +01:00
Zoltán Papp
f9f6409f94 Remove context log from grpc client 2025-11-18 17:26:33 +01:00
Zoltán Papp
b03154dce5 Use dedicated ctx in stream 2025-11-18 17:25:46 +01:00
Zoltán Papp
c57364596a Fix log message 2025-11-18 17:15:16 +01:00
Viktor Liu
60f4d5f9b0 [client] Revert migrate deprecated grpc client code #4805 2025-11-18 12:41:17 +01:00
Vlad
4eeb2d8deb [management] added exception on not appending route firewall rules if we have all wildcard (#4801) 2025-11-17 18:20:30 +01:00
Zoltán Papp
2765bcfb89 Restore in case of error 2025-11-17 17:24:19 +01:00
Viktor Liu
d71a82769c [client,management] Rewrite the SSH feature (#4015) 2025-11-17 17:10:41 +01:00
Zoltán Papp
fa6151b849 Fix error message 2025-11-17 16:13:29 +01:00
Zoltán Papp
a939c1767c Fix error message 2025-11-17 16:09:26 +01:00
Zoltán Papp
938554fb0f Implement time for parameter usage 2025-11-17 15:47:27 +01:00
Misha Bragin
0d79301141 Update client login success page (#4797) 2025-11-17 15:28:20 +01:00
Zoltán Papp
39bec2dd74 Truncate too long error response 2025-11-17 14:01:21 +01:00
Zoltán Papp
554c9bcf4b Do not expect last sync for debug bundle 2025-11-17 12:04:20 +01:00
Zoltán Papp
f3639675e7 Fix string conversation 2025-11-17 12:02:58 +01:00
Zoltán Papp
a1457f541b Handle job error responses on mgm side 2025-11-17 11:49:46 +01:00
Zoltán Papp
9cdfb0d78c Handle unimplemented job type 2025-11-17 11:45:12 +01:00
Zoltán Papp
22d796097e Add profile name and events for status 2025-11-17 09:46:54 +01:00
Zoltán Papp
aa39a5d528 Add logs 2025-11-14 13:22:08 +01:00
Viktor Liu
e4b41d0ad7 [client] Replace ipset lib (#4777)
* Replace ipset lib

* Update .github/workflows/check-license-dependencies.yml

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Ignore internal licenses

* Ignore dependencies from AGPL code

* Use exported errors

* Use fixed version

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-11-14 00:25:00 +01:00
Viktor Liu
9cc9462cd5 [client] Use stdnet with a context to avoid DNS deadlocks (#4781) 2025-11-13 20:16:45 +01:00
Diego Romar
3176b53968 [client] Add quick actions window (#4717)
* Open quick settings window if netbird-ui is already running

* [client-ui] fix connection status comparison

* [client-ui] modularize quick actions code

* [client-ui] add netbird-disconnected logo

* [client-ui] change quickactions UI

It now displays the NetBird logo and a single button
with a round icon

* [client-ui] add hint message to quick actions screen

This also updates fyne to v2.7.0

* [client-ui] remove unnecessary default clause

* [client-ui] remove commented code

* [client-ui] remove unused dependency

* [client-ui] close quick actions on connection change

* [client-ui] add function to get image from embed resources

* [client] Return error when calling sendShowWindowSignal from Windows

* [client-ui] Add commentary on empty OnTapped function for toggleConnectionButton

* [client-ui] Fix tests

* [client-ui] Add context to menuUpClick call

* [client-ui] Pass serviceClient app as parameter

To use its clipboard rather than the window's when showing
the upload success dialog

* [client-ui] Replace for select with for range chan

* [client-ui] Replace settings change listener channel

Settings now accept a function callback

* [client-ui] Add missing iconAboutDisconnected to icons_windows.go

* [client] Add quick actions signal handler for Windows with named events

* [client] Run go mod tidy

* [client] Remove line break

* [client] Log unexpected status in separate function

* [client-ui] Refactor quick actions window

To address racing conditions, it also replaces
usage of pause and resume channels with an
atomic bool.

* [client-ui] use derived context from ServiceClient

* [client] Update signal_windows log message

Also, format error when trying to set event on
sendShowWindowSignal

* go mod tidy

* [client-ui] Add struct to pass fewer parameters

to applyQuickActionsUiState function

* [client] Add missing import

---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2025-11-13 10:25:19 -03:00
Viktor Liu
27957036c9 [client] Fix shutdown blocking on stuck ICE agent close (#4780) 2025-11-13 13:24:51 +01:00
Pascal Fischer
6fb568728f [management] Removed policy posture checks on original peer (#4779)
Co-authored-by: crn4 <vladimir@netbird.io>
2025-11-13 12:51:03 +01:00
Pascal Fischer
cc97cffff1 [management] move network map logic into new design (#4774) 2025-11-13 12:09:46 +01:00
aliamerj
1d2a5371ce remove EOF skipping 2025-11-12 18:48:53 +03:00
Zoltán Papp
6898e57686 Fix nil pointer check 2025-11-12 13:36:32 +01:00
Zoltán Papp
c8bc865f2f Fix error message 2025-11-12 13:32:14 +01:00
Zoltán Papp
06bb8658b1 Fix validation 2025-11-12 13:30:29 +01:00
Zoltán Papp
8fc4fed3a0 Fix SQL query syntax 2025-11-12 13:20:08 +01:00
Zoltán Papp
df14f1399f Fix MockAccountManager function calls checks 2025-11-12 13:18:40 +01:00
Zoltán Papp
6d6f090764 Merge branch 'main' into feature/remote-debug-clean 2025-11-12 12:48:20 +01:00
Zoltan Papp
c28275611b Fix agent reference (#4776) 2025-11-11 13:59:32 +01:00
Vlad
56f169eede [management] fix pg db deadlock after app panic (#4772) 2025-11-10 23:43:08 +01:00
Viktor Liu
07cf9d5895 [client] Create networkd.conf.d if it doesn't exist (#4764) 2025-11-08 10:54:37 +01:00
Pascal Fischer
7df49e249d [management ] remove timing logs (#4761) 2025-11-07 20:14:52 +01:00
Pascal Fischer
dbfc8a52c9 [management] remove GLOBAL when disabling foreign keys on mysql (#4615) 2025-11-07 16:03:14 +01:00
Vlad
98ddac07bf [management] remove toAll firewall rule (#4725) 2025-11-07 15:50:58 +01:00
Pascal Fischer
48475ddc05 [management] add pat rate limiting (#4741) 2025-11-07 15:50:18 +01:00
Vlad
6aa4ba7af4 [management] incremental network map builder (#4753) 2025-11-07 10:44:46 +01:00
dependabot[bot]
2e16c9914a [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29.
- [Release notes](https://github.com/containerd/containerd/releases)
- [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md)
- [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29)

---
updated-dependencies:
- dependency-name: github.com/containerd/containerd
  dependency-version: 1.7.29
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-06 19:01:44 +03:00
Pascal Fischer
5c29d395b2 [management] activity events on group updates (#4750) 2025-11-06 12:51:14 +01:00
Viktor Liu
229e0038ee [client] Add dns config to debug bundle (#4704) 2025-11-05 17:30:17 +01:00
Viktor Liu
75327d9519 [client] Add login_hint to oidc flows (#4724) 2025-11-05 17:00:20 +01:00
Viktor Liu
c92e6c1b5f [client] Block on all subsystems on shutdown (#4709) 2025-11-05 12:15:37 +01:00
Viktor Liu
641eb5140b [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) 2025-11-04 21:56:53 +01:00
Viktor Liu
45c25dca84 [client] Clamp MSS on outbound traffic (#4735) 2025-11-04 17:18:51 +01:00
Viktor Liu
679c58ce47 [client] Set up networkd to ignore ip rules (#4730) 2025-11-04 17:06:35 +01:00
aliamerj
13febbbfca update event metadata for creating new job 2025-11-04 15:24:46 +03:00
aliamerj
49d36b7e7e rename logFile to logPath 2025-11-04 10:09:34 +03:00
Pascal Fischer
719283c792 [management] update db connection lifecycle configuration (#4740) 2025-11-03 17:40:12 +01:00
dependabot[bot]
a2313a5ba4 [client] Bump github.com/quic-go/quic-go from 0.48.2 to 0.49.1 (#4621)
Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.48.2 to 0.49.1.
- [Release notes](https://github.com/quic-go/quic-go/releases)
- [Commits](https://github.com/quic-go/quic-go/compare/v0.48.2...v0.49.1)

---
updated-dependencies:
- dependency-name: github.com/quic-go/quic-go
  dependency-version: 0.49.1
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-01 15:27:22 +01:00
Zoltan Papp
8c108ccad3 [client] Extend Darwin network monitoring with wakeup detection 2025-10-31 19:19:14 +01:00
Viktor Liu
86eff0d750 [client] Fix netstack dns forwarder (#4727) 2025-10-31 14:18:09 +01:00
aliamerj
976787dbf1 Merge remote-tracking branch 'upstream/main' into feature/remote-debug-clean 2025-10-30 14:30:00 +03:00
Viktor Liu
43c9a51913 [client] Migrate deprecated grpc client code (#4687) 2025-10-30 10:14:27 +01:00
Viktor Liu
c530db1455 [client] Fix UI panic when switching profiles (#4718) 2025-10-29 17:27:18 +01:00
Viktor Liu
1ee575befe [client] Use management-provided dns forwarder port on the client side (#4712) 2025-10-28 22:58:43 +01:00
Viktor Liu
d3a34adcc9 [client] Fix Connect/Disconnect buttons being enabled or disabled at the same time (#4711) 2025-10-28 21:21:40 +01:00
Zoltan Papp
d7321c130b [client] The status cmd will not be blocked by the ICE probe (#4597)
The status cmd will not be blocked by the ICE probe

Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state.
2025-10-28 16:11:35 +01:00
Viktor Liu
404cab90ba [client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707)
- Port dnat changes from https://github.com/netbirdio/netbird/pull/4015 (nftables/iptables/userspace)
  - For userspace: rewrite the original port to the target port
  - Remember original destination port in conntrack
  - Rewrite the source port back to the original port for replies
- Redirect incoming port 5353 to 22054 (tcp/udp)
- Revert port changes based on the network map received from management
- Adjust tracer to show NAT stages
2025-10-28 15:12:53 +01:00
Pascal Fischer
4545ab9a52 [management] rewire account manager to permissions manager (#4673) 2025-10-27 22:59:35 +01:00
Bethuel Mmbaga
7f08983207 Include expired and routing peers in DNS record filtering (#4708) 2025-10-27 22:16:17 +03:00
Viktor Liu
eddea14521 [client] Clean up bsd routes independently of the state file (#4688) 2025-10-27 18:54:00 +01:00
Viktor Liu
b9ef214ea5 [client] Fix macOS state-based dns cleanup (#4701) 2025-10-27 18:35:32 +01:00
Bethuel Mmbaga
709e24eb6f [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644)
This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS.
2025-10-24 15:40:20 +03:00
Viktor Liu
6654e2dbf7 [client] Fix active profile name in debug bundle (#4689) 2025-10-23 17:07:52 +02:00
Bethuel Mmbaga
d80d47a469 [management] Add peer disapproval reason (#4468) 2025-10-22 12:46:22 +03:00
Maycon Santos
96f71ff1e1 [misc] Update tag name extraction in install.sh (#4677) 2025-10-21 19:23:11 +02:00
Viktor Liu
2fe2af38d2 [client] Clean up match domain reg entries between config changes (#4676) 2025-10-21 18:14:39 +02:00
aliamerj
536b0003ab handle unimplemented Job stream 2025-10-18 17:10:51 +03:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
Zoltan Papp
af95aabb03 Handle the case when the service has already been down and the status recorder is not available (#4652) 2025-10-16 17:15:39 +02:00
Viktor Liu
3abae0bd17 [client] Set default wg port for new profiles (#4651) 2025-10-16 16:16:51 +02:00
Viktor Liu
8252ff41db [client] Add bind activity listener to bypass udp sockets (#4646) 2025-10-16 15:58:29 +02:00
Viktor Liu
277aa2b7cc [client] Fix missing flag values in profiles (#4650) 2025-10-16 15:13:41 +02:00
John Conley
bb37dc89ce [management] feat: Basic PocketID IDP integration (#4529) 2025-10-16 10:46:29 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
aliamerj
0e9438d658 fix lint 2025-10-10 18:00:15 +03:00
aliamerj
e570570fe5 add some info level log 2025-10-10 17:47:15 +03:00
aliamerj
23f9dd04b8 clean up 2025-10-10 17:17:35 +03:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
aliamerj
7a95bf5652 fix bug with missing logs file 2025-10-10 17:14:08 +03:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
Zoltan Papp
9021bb512b [client] Recreate agent when receive new session id (#4564)
When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time.
In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started.
2025-10-08 17:14:24 +02:00
hakansa
768332820e [client] Implement DNS query caching in DNSForwarder (#4574)
implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails.

- Added a new cache component to store DNS query results by domain and query type
- Integrated cache storage after successful DNS resolutions
- Enhanced error handling to serve cached responses as fallback when upstream DNS fails
2025-10-08 16:54:27 +02:00
hakansa
229c65ffa1 Enhance showLoginURL to include connection status check and auto-close functionality (#4525) 2025-10-08 12:42:15 +02:00
Zoltan Papp
4d33567888 [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228)
* When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder.
2025-10-08 03:12:16 +02:00
Viktor Liu
88467883fc [management,signal] Remove ws-proxy read deadline (#4598) 2025-10-06 22:05:48 +02:00
Viktor Liu
954f40991f [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) 2025-10-06 21:22:19 +02:00
Maycon Santos
34341d95a9 Adjust signal port for websocket connections (#4594) 2025-10-06 15:22:02 -03:00
aliamerj
60c5782905 fix other lint issue 2025-10-06 17:56:48 +03:00
aliamerj
5a12c5d345 fix everything 2025-10-06 17:45:36 +03:00
Ali Amer
bdae55ab79 update yml for job (#4532) 2025-10-06 16:13:10 +03:00
Ali Amer
d01c3d5011 [management/client] Integrate Job API with Job Stream and Client Engine (#4428)
* integrate api

integrate api with stream and implement some client side

* fix typo and fix validation

* use real daemon address

* redo the connect via address

* Refactor the debug bundle generator to be ready to use from engine (#4469)

* fix tests

* fix lint

* fix bug with stream

* try refactor status 1

* fix convert fullStatus to statusOutput & add logFile

* fix tests

* fix tests

* fix not enough arguments in call to nbstatus.ConvertToStatusOutputOverview

* fix status_test

* fix(engine): avoid deadlock when stopping engine during debug bundle

* use atomic for lock-free

* use new lock

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2025-10-06 16:13:07 +03:00
Ali Amer
17a2af96ea implement remote debug api (#4418)
fix lint

clean up

fix MarkPendingJobsAsFailed

apply feedbacks 1

fix typo

change api and apply new schema

fix lint

fix api object

clean switch case

apply feedback 2

fix error handle in create job

get rid of any/interface type in job database

fix sonar issue

use RawJson for both parameters and results

running go mod tidy

update package

fix 1

update codegen

fix code-gen

fix snyk

fix snyk hopefully
2025-10-06 16:05:27 +03:00
Ali Amer
cc595da1ad [management/client] create job channel between management and client (#4367)
* new bi-directional stream for jobs

* create bidirectional job channel to send requests from the server and receive responses from the client

* fix tests

* fix lint and close bug

* fix lint

* clean up & fix close of closed channel

* add nolint:staticcheck

* remove some redundant code from the job channel PR since this one is a cleaner rewrite

* cleanup removes a pending job safely

* change proto

* rename to jobRequest

* apply feedback 1

* apply feedback 2

* fix typo

* apply feedback 3

* apply last feedback
2025-10-06 15:56:04 +03:00
556 changed files with 44703 additions and 10037 deletions

View File

@@ -3,39 +3,108 @@ name: Check License Dependencies
on:
push:
branches: [ main ]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs:
check-dependencies:
check-internal-dependencies:
name: Check Internal AGPL Dependencies
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ ! -z "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
done
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages will change license and should not be imported by client or shared code"
exit 1
else
echo ""
echo "✅ All license dependencies are clean"
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -16,19 +16,29 @@ jobs:
steps:
- name: Read PR body
id: body
shell: bash
run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
echo "body<<EOF" >> $GITHUB_OUTPUT
echo "$BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
set -euo pipefail
BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH")
{
echo "body_b64=$BODY_B64"
} >> "$GITHUB_OUTPUT"
- name: Validate checkbox selection
id: validate
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: |
body='${{ steps.body.outputs.body }}'
set -euo pipefail
if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then
echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing."
exit 1
fi
added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true)
noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true)
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
@@ -41,30 +51,35 @@ jobs:
fi
if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT
echo "mode=added" >> "$GITHUB_OUTPUT"
else
echo "mode=noneed" >> $GITHUB_OUTPUT
echo "mode=noneed" >> "$GITHUB_OUTPUT"
fi
- name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added'
id: extract
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: |
body='${{ steps.body.outputs.body }}'
set -euo pipefail
body="$(printf '%s' "$BODY_B64" | base64 -d)"
# Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted:
# https://github.com/netbirdio/docs/pull/1234
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
# e.g., https://github.com/netbirdio/docs/pull/1234
url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)"
if [ -z "$url" ]; then
if [ -z "${url:-}" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1
fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
echo "url=$url" >> $GITHUB_OUTPUT
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')"
{
echo "url=$url"
echo "pr_number=$pr_number"
} >> "$GITHUB_OUTPUT"
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -382,6 +382,32 @@ jobs:
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -428,9 +454,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/...
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
api_benchmark:
name: "Management / Benchmark (API)"
@@ -521,7 +548,7 @@ jobs:
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/...
-timeout 20m ./management/server/http/...
api_integration_test:
name: "Management / Integration"
@@ -571,4 +598,4 @@ jobs:
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/...
-timeout 20m ./management/server/http/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.22"
SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -47,7 +47,7 @@ jobs:
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
run: GOOS=js GOARCH=wasm go build -o netbird.wasm -ldflags="-s -w" ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size

View File

@@ -1,3 +1,4 @@
<div align="center">
<br/>
<br/>
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -4,6 +4,7 @@ package android
import (
"context"
"os"
"slices"
"sync"
@@ -16,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net"
)
// ConnectionListener export internal Listener for mobile
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -112,13 +114,14 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "")
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -138,7 +141,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "")
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -0,0 +1,32 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
@@ -194,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
if err != nil {
return nil, err
}

View File

@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetEnableSSHRoot reads SSH root login setting from config file
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
if p.configInput.EnableSSHRoot != nil {
return *p.configInput.EnableSSHRoot, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRoot == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRoot, err
}
// SetEnableSSHRoot stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
p.configInput.EnableSSHRoot = &enabled
}
// GetEnableSSHSFTP reads SSH SFTP setting from config file
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
if p.configInput.EnableSSHSFTP != nil {
return *p.configInput.EnableSSHSFTP, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHSFTP == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHSFTP, err
}
// SetEnableSSHSFTP stores the given value and waits for commit
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
p.configInput.EnableSSHSFTP = &enabled
}
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
if p.configInput.EnableSSHLocalPortForwarding != nil {
return *p.configInput.EnableSSHLocalPortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHLocalPortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHLocalPortForwarding, err
}
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
p.configInput.EnableSSHLocalPortForwarding = &enabled
}
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
if p.configInput.EnableSSHRemotePortForwarding != nil {
return *p.configInput.EnableSSHRemotePortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRemotePortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRemotePortForwarding, err
}
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
p.configInput.EnableSSHRemotePortForwarding = &enabled
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -168,7 +166,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
@@ -220,9 +218,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
@@ -230,11 +225,8 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -301,19 +293,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp.GetFullStatus(), anon, statusResp.GetDaemonVersion(), "", nil, nil, nil, "", ""),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -372,7 +351,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogFile: logFilePath,
LogPath: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
@@ -105,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
@@ -227,7 +235,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
// update host's static platform and system information
system.UpdateStaticInfo()
system.UpdateStaticInfoAsync()
configFilePath, err := activeProf.FilePath()
if err != nil {
@@ -240,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -268,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
needsLogin := false
err := WithBackOff(func() error {
@@ -285,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -314,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil {
return nil, err
}
@@ -356,13 +373,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -35,10 +35,10 @@ const (
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
mtuFlag = "mtu"
)
var (
@@ -63,7 +63,6 @@ var (
customDNSAddress string
rosenpassEnabled bool
rosenpassPermissive bool
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
networkMonitor bool
@@ -72,6 +71,7 @@ var (
anonymizeFlag bool
dnsRouteInterval time.Duration
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
@@ -174,7 +174,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
@@ -229,7 +228,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(

View File

@@ -54,6 +54,7 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()

View File

@@ -10,6 +10,8 @@ import (
"path/filepath"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/kardianos/service"
"github.com/spf13/cobra"
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
svcConfig.Option["LogDirectory"] = dir
}
}
if err := configureSystemdNetworkd(); err != nil {
log.Warnf("failed to configure systemd-networkd: %v", err)
}
}
if runtime.GOOS == "windows" {
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err)
}
if runtime.GOOS == "linux" {
if err := cleanupSystemdNetworkd(); err != nil {
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
}
}
cmd.Println("NetBird service has been uninstalled")
return nil
},
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
return status == service.StatusRunning, nil
}
const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
# routes and policy rules managed by NetBird.
[Network]
ManageForeignRoutes=no
ManageForeignRoutingPolicyRules=no
`
)
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
log.Debug("systemd-networkd not in use, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)
}
return nil
}
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
func cleanupSystemdNetworkd() error {
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
return nil
}
if err := os.Remove(networkdConfFile); err != nil {
return fmt.Errorf("remove networkd configuration: %w", err)
}
return nil
}

View File

@@ -3,125 +3,809 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
var (
port int
userName = "root"
host string
const (
sshUsernameDesc = "SSH username"
hostArgumentRequired = "host argument required"
serverSSHAllowedFlag = "allow-server-ssh"
enableSSHRootFlag = "enable-ssh-root"
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
var (
port int
username string
host string
command string
localForwards []string
remoteForwards []string
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
requestPTY bool
)
split := strings.Split(args[0], "@")
if len(split) == 2 {
userName = split[0]
host = split[1]
} else {
host = args[0]
}
var (
serverSSHAllowed bool
enableSSHRoot bool
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
)
return nil
},
Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
cmd.SetOut(cmd.OutOrStdout())
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
}
ctx := internal.CtxInitState(cmd.Context())
sm := profilemanager.NewServiceManager(configPath)
activeProf, err := sm.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()
}()
select {
case <-sig:
cancel()
case <-sshctx.Done():
}
return nil
},
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
var sshCmd = &cobra.Command{
Use: "ssh [flags] [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
Port Forwarding:
-L [bind_address:]port:host:hostport Local port forwarding
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
-R [bind_address:]port:host:hostport Remote port forwarding
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
SSH Options:
-p, --port int Remote SSH port (default 22)
-u, --user string SSH username
--login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file
Examples:
netbird ssh peer-hostname
netbird ssh root@peer-hostname
netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: sshFn,
Aliases: []string{"ssh"},
}
func sshFn(cmd *cobra.Command, args []string) error {
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return cmd.Help()
}
}
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err
}
cancel()
}()
err = c.OpenTerminal()
if err != nil {
select {
case <-sig:
cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done():
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
func getEnvOrDefault(flagName, defaultValue string) string {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
return envValue
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
return envValue
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
username = ""
host = ""
command = ""
localForwards = nil
remoteForwards = nil
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
var localForwardFlags []string
var remoteForwardFlags []string
var filteredArgs []string
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case strings.HasPrefix(arg, "-L"):
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
case strings.HasPrefix(arg, "-R"):
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
default:
filteredArgs = append(filteredArgs, arg)
}
}
return filteredArgs, localForwardFlags, remoteForwardFlags
}
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
if arg == "-L" || arg == "-R" {
if i+1 < len(args) {
flags = append(flags, args[i+1])
i++
}
} else if len(arg) > 2 {
flags = append(flags, arg[2:])
}
return flags, i
}
// extractGlobalFlags parses global flags that were passed before 'ssh' command
func extractGlobalFlags(args []string) {
sshPos := findSSHCommandPosition(args)
if sshPos == -1 {
return
}
globalArgs := args[:sshPos]
parseGlobalArgs(globalArgs)
}
// findSSHCommandPosition locates the 'ssh' command in the argument list
func findSSHCommandPosition(args []string) int {
for i, arg := range args {
if arg == "ssh" {
return i
}
}
return -1
}
const (
configFlag = "config"
logLevelFlag = "log-level"
logFileFlag = "log-file"
)
// parseGlobalArgs processes the global arguments and sets the corresponding variables
func parseGlobalArgs(globalArgs []string) {
flagHandlers := map[string]func(string){
configFlag: func(value string) { configPath = value },
logLevelFlag: func(value string) { logLevel = value },
logFileFlag: func(value string) {
if !slices.Contains(logFiles, value) {
logFiles = append(logFiles, value)
}
},
}
shortFlags := map[string]string{
"c": configFlag,
"l": logLevelFlag,
}
for i := 0; i < len(globalArgs); i++ {
arg := globalArgs[i]
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
i = nextIndex
}
}
}
// parseFlag handles generic flag parsing for both long and short forms
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex
}
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex + 1
}
return false, currentIndex
}
type parsedFlag struct {
flagName string
value string
}
// parseEqualsFormat handles --flag=value and -f=value formats
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if !strings.Contains(arg, "=") {
return parsedFlag{}, false
}
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return parsedFlag{}, false
}
if strings.HasPrefix(parts[0], "--") {
flagName := strings.TrimPrefix(parts[0], "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: parts[1]}, true
}
}
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
shortFlag := strings.TrimPrefix(parts[0], "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: parts[1]}, true
}
}
}
return parsedFlag{}, false
}
// parseSpacedFormat handles --flag value and -f value formats
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if currentIndex+1 >= len(args) {
return parsedFlag{}, false
}
if strings.HasPrefix(arg, "--") {
flagName := strings.TrimPrefix(arg, "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
}
}
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
shortFlag := strings.TrimPrefix(arg, "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
}
}
}
return parsedFlag{}, false
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
RequestPTY bool
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
resetSSHGlobals()
if len(os.Args) > 2 {
extractGlobalFlags(os.Args[1:])
}
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
}
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New(hostArgumentRequired)
}
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
remoteForwards = remoteForwardFlags
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
return fmt.Errorf("dial %s: %w", target, err)
}
sshCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-sshCtx.Done()
if err := c.Close(); err != nil {
cmd.Printf("Error closing SSH connection: %v\n", err)
}
}()
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
return fmt.Errorf("start port forwarding: %w", err)
}
if command != "" {
return executeSSHCommand(sshCtx, c, command)
}
return openSSHTerminal(sshCtx, c)
}
// executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err)
}
return nil
}
// openSSHTerminal opens an interactive SSH terminal.
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if err := c.OpenTerminal(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err)
}
return nil
}
// startPortForwarding starts local and remote port forwarding based on command line flags
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
for _, forward := range localForwards {
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("local port forward %s: %w", forward, err)
}
}
for _, forward := range remoteForwards {
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("remote port forward %s: %w", forward, err)
}
}
return nil
}
// parseAndStartLocalForward parses and starts a local port forward (-L)
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Local port forward error: %v\n", err)
}
}()
return nil
}
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Remote port forward error: %v\n", err)
}
}()
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
// Support formats:
// port:host:hostport -> localhost:port -> host:hostport
// host:port:host:hostport -> host:port -> host:hostport
// [host]:port:host:hostport -> [host]:port -> host:hostport
// port:unix_socket_path -> localhost:port -> unix_socket_path
// host:port:unix_socket_path -> host:port -> unix_socket_path
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
return parseIPv6ForwardSpec(spec)
}
parts := strings.Split(spec, ":")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
}
switch len(parts) {
case 2:
return parseTwoPartForwardSpec(parts, spec)
case 3:
return parseThreePartForwardSpec(parts)
case 4:
return parseFourPartForwardSpec(parts)
default:
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
}
}
// parseTwoPartForwardSpec handles "port:unix_socket" format.
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
if isUnixSocket(parts[1]) {
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1]
return localAddr, remoteAddr, nil
}
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
}
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
func parseThreePartForwardSpec(parts []string) (string, string, error) {
if isUnixSocket(parts[2]) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2]
return localAddr, remoteAddr, nil
}
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
func parseFourPartForwardSpec(parts []string) (string, string, error) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2] + ":" + parts[3]
return localAddr, remoteAddr, nil
}
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
func parseIPv6ForwardSpec(spec string) (string, string, error) {
idx := strings.Index(spec, "]:")
if idx == -1 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
}
ipv6Host := spec[:idx+1]
remaining := spec[idx+2:]
parts := strings.Split(remaining, ":")
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
}
localAddr := ipv6Host + ":" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// isUnixSocket checks if a path is a Unix socket path.
func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string {
if host == "*" {
return "0.0.0.0"
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -0,0 +1,74 @@
//go:build unix
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sshExecUID uint32
sshExecGID uint32
sshExecGroups []uint
sshExecWorkingDir string
sshExecShell string
sshExecCommand string
sshExecPTY bool
)
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
var sshExecCmd = &cobra.Command{
Use: "exec",
Short: "Internal SSH execution with privilege dropping (hidden)",
Hidden: true,
RunE: runSSHExec,
}
func init() {
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
os.Exit(1)
}
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
os.Exit(1)
}
sshCmd.AddCommand(sshExecCmd)
}
// runSSHExec handles the SSH exec subcommand execution.
func runSSHExec(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sshExecGroups {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sshExecUID,
GID: sshExecGID,
Groups: groups,
WorkingDir: sshExecWorkingDir,
Shell: sshExecShell,
Command: sshExecCommand,
PTY: sshExecPTY,
}
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build unix
package cmd
import (
"errors"
"io"
"os"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpUID uint32
sftpGID uint32
sftpGroupsInt []uint
sftpWorkingDir string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with privilege dropping (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sftpGroupsInt {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sftpUID,
GID: sftpGID,
Groups: groups,
WorkingDir: sftpWorkingDir,
Shell: "",
Command: "",
}
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
cmd.PrintErrf("privilege drop failed: %v\n", err)
os.Exit(sshserver.ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Tracef("starting SFTP server with dropped privileges")
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeShellExecFail)
}
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeSuccess)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build windows
package cmd
import (
"errors"
"fmt"
"io"
"os"
"os/user"
"strings"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpWorkingDir string
windowsUsername string
windowsDomain string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with user switching for Windows (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
return sftpMainDirect(cmd)
}
func sftpMainDirect(cmd *cobra.Command) error {
currentUser, err := user.Current()
if err != nil {
cmd.PrintErrf("failed to get current user: %v\n", err)
os.Exit(sshserver.ExitCodeValidationFail)
}
if windowsUsername != "" {
expectedUsername := windowsUsername
if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
}
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail)
}
}
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
if sftpWorkingDir != "" {
if err := os.Chdir(sftpWorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Debugf("starting SFTP server")
exitCode := sshserver.ExitCodeSuccess
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
exitCode = sshserver.ExitCodeShellExecFail
}
if err := sftpServer.Close(); err != nil {
log.Debugf("SFTP server close error: %v", err)
}
os.Exit(exitCode)
return nil
}

717
client/cmd/ssh_test.go Normal file
View File

@@ -0,0 +1,717 @@
package cmd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSHCommand_FlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedUser string
expectedPort int
expectedCmd string
expectError bool
}{
{
name: "basic host",
args: []string{"hostname"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "",
},
{
name: "user@host format",
args: []string{"user@hostname"},
expectedHost: "hostname",
expectedUser: "user",
expectedPort: 22,
expectedCmd: "",
},
{
name: "host with command",
args: []string{"hostname", "echo", "hello"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "echo hello",
},
{
name: "command with flags should be preserved",
args: []string{"hostname", "ls", "-la", "/tmp"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "ls -la /tmp",
},
{
name: "double dash separator",
args: []string{"hostname", "--", "ls", "-la"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "-- ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
// Mock command for testing
cmd := sshCmd
cmd.SetArgs(tt.args)
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
if tt.expectedUser != "" {
assert.Equal(t, tt.expectedUser, username, "username mismatch")
}
assert.Equal(t, tt.expectedPort, port, "port mismatch")
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
})
}
}
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
// Test that SSH flags don't conflict with command flags
tests := []struct {
name string
args []string
expectedCmd string
description string
}{
{
name: "ls with -la flags",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
description: "ls flags should be passed to remote command",
},
{
name: "grep with -r flag",
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
expectedCmd: "grep -r pattern /path",
description: "grep flags should be passed to remote command",
},
{
name: "ps with aux flags",
args: []string{"hostname", "ps", "aux"},
expectedCmd: "ps aux",
description: "ps flags should be passed to remote command",
},
{
name: "command with double dash",
args: []string{"hostname", "--", "ls", "-la"},
expectedCmd: "-- ls -la",
description: "double dash should be preserved in command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
// Test that commands with arguments should execute the command and exit,
// not drop to an interactive shell
tests := []struct {
name string
args []string
expectedCmd string
shouldExit bool
description string
}{
{
name: "ls command should execute and exit",
args: []string{"hostname", "ls"},
expectedCmd: "ls",
shouldExit: true,
description: "ls command should execute and exit, not drop to shell",
},
{
name: "ls with flags should execute and exit",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
shouldExit: true,
description: "ls with flags should execute and exit, not drop to shell",
},
{
name: "pwd command should execute and exit",
args: []string{"hostname", "pwd"},
expectedCmd: "pwd",
shouldExit: true,
description: "pwd command should execute and exit, not drop to shell",
},
{
name: "echo command should execute and exit",
args: []string{"hostname", "echo", "hello"},
expectedCmd: "echo hello",
shouldExit: true,
description: "echo command should execute and exit, not drop to shell",
},
{
name: "no command should open shell",
args: []string{"hostname"},
expectedCmd: "",
shouldExit: false,
description: "no command should open interactive shell",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// When command is present, it should execute the command and exit
// When command is empty, it should open interactive shell
hasCommand := command != ""
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
})
}
}
func TestSSHCommand_FlagHandling(t *testing.T) {
// Test that flags after hostname are not parsed by netbird but passed to SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "ls with -la flag should not be parsed by netbird",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
},
{
name: "command with netbird-like flags should be passed through",
args: []string{"hostname", "echo", "--help"},
expectedHost: "hostname",
expectedCmd: "echo --help",
expectError: false,
description: "--help should be passed to echo, not parsed by netbird",
},
{
name: "command with -p flag should not conflict with SSH port flag",
args: []string{"hostname", "ps", "-p", "1234"},
expectedHost: "hostname",
expectedCmd: "ps -p 1234",
expectError: false,
description: "ps -p should be passed to ps command, not parsed as port",
},
{
name: "tar with flags should be passed through",
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
expectedHost: "hostname",
expectedCmd: "tar -czf backup.tar.gz /home",
expectError: false,
description: "tar flags should be passed to tar command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
// should not parse -la as netbird flags but pass them to the SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "original issue: ls -la should be preserved",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "The original failing case should now work",
},
{
name: "ls -l should be preserved",
args: []string{"hostname", "ls", "-l"},
expectedHost: "hostname",
expectedCmd: "ls -l",
expectError: false,
description: "Single letter flags should be preserved",
},
{
name: "SSH port flag should work",
args: []string{"-p", "2222", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedCmd: "ls -la",
expectError: false,
description: "SSH -p flag should be parsed, command flags preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// Check port for the test case with -p flag
if len(tt.args) > 0 && tt.args[0] == "-p" {
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
}
})
}
}
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectError bool
description string
}{
{
name: "local port forwarding -L",
args: []string{"-L", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Single -L flag should be parsed correctly",
},
{
name: "remote port forwarding -R",
args: []string{"-R", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectError: false,
description: "Single -R flag should be parsed correctly",
},
{
name: "multiple local port forwards",
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
expectedRemote: []string{},
expectError: false,
description: "Multiple -L flags should be parsed correctly",
},
{
name: "multiple remote port forwards",
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
expectError: false,
description: "Multiple -R flags should be parsed correctly",
},
{
name: "mixed local and remote forwards",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectError: false,
description: "Mixed -L and -R flags should be parsed correctly",
},
{
name: "port forwarding with bind address",
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with bind address should work",
},
{
name: "port forwarding with command",
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with command should work",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
})
}
}
func TestParsePortForward(t *testing.T) {
tests := []struct {
name string
spec string
expectedLocal string
expectedRemote string
expectError bool
description string
}{
{
name: "simple port forward",
spec: "8080:localhost:80",
expectedLocal: "localhost:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Simple port:host:port format should work",
},
{
name: "port forward with bind address",
spec: "127.0.0.1:8080:localhost:80",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "bind_address:port:host:port format should work",
},
{
name: "port forward to different host",
spec: "8080:example.com:443",
expectedLocal: "localhost:8080",
expectedRemote: "example.com:443",
expectError: false,
description: "Forwarding to different host should work",
},
{
name: "port forward with IPv6 (needs bracket support)",
spec: "::1:8080:localhost:80",
expectError: true,
description: "IPv6 without brackets fails as expected (feature to implement)",
},
{
name: "invalid format - too few parts",
spec: "8080:localhost",
expectError: true,
description: "Invalid format with too few parts should fail",
},
{
name: "invalid format - too many parts",
spec: "127.0.0.1:8080:localhost:80:extra",
expectError: true,
description: "Invalid format with too many parts should fail",
},
{
name: "empty spec",
spec: "",
expectError: true,
description: "Empty spec should fail",
},
{
name: "unix socket local forward",
spec: "8080:/tmp/socket",
expectedLocal: "localhost:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket forwarding should work",
},
{
name: "unix socket with bind address",
spec: "127.0.0.1:8080:/tmp/socket",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket with bind address should work",
},
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
},
{
name: "wildcard for port only",
spec: "8080:*:80",
expectedLocal: "localhost:8080",
expectedRemote: "*:80",
expectError: false,
description: "Wildcard in remote host should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
if tt.expectError {
assert.Error(t, err, tt.description)
return
}
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
})
}
}
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
// Integration test for port forwarding with the actual SSH command implementation
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectedCmd string
description string
}{
{
name: "local forward with command",
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectedCmd: "echo test",
description: "Local forwarding should work with commands",
},
{
name: "remote forward with command",
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectedCmd: "ls -la",
description: "Remote forwarding should work with commands",
},
{
name: "multiple forwards with user and command",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectedCmd: "ps aux",
description: "Complex case with multiple forwards, user, and command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
})
}
}
func TestSSHCommand_ParameterIsolation(t *testing.T) {
tests := []struct {
name string
args []string
expectedCmd string
}{
{
name: "cmd flag passed as command",
args: []string{"hostname", "--cmd", "echo test"},
expectedCmd: "--cmd echo test",
},
{
name: "uid flag passed as command",
args: []string{"hostname", "--uid", "1000"},
expectedCmd: "--uid 1000",
},
{
name: "shell flag passed as command",
args: []string{"hostname", "--shell", "/bin/bash"},
expectedCmd: "--shell /bin/bash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
require.NoError(t, err)
assert.Equal(t, "hostname", host)
assert.Equal(t, tt.expectedCmd, command)
})
}
}
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
}
if err != nil {
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -9,34 +9,33 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func startTestingServices(t *testing.T) string {
@@ -90,21 +89,25 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
jobManager := mgmt.NewJobManager(nil, store)
jobManager := job.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT().
@@ -112,13 +115,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -63,6 +63,7 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -184,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -196,11 +197,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
//todo: do we need to pass logFile here ?
connectClient := internal.NewConnectClient(ctx, config, r, "")
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
@@ -284,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error
var loginResp *proto.LoginResponse
@@ -346,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -358,6 +386,11 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.WireguardPort = &p
}
if cmd.Flag(mtuFlag).Changed {
m := int64(mtu)
req.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor
}
@@ -425,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -437,6 +494,13 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.WireguardPort = &p
}
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
ic.MTU = &mtu
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
@@ -518,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}
@@ -534,6 +623,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.WireguardPort = &wp
}
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
m := int64(mtu)
loginRequest.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor
}

View File

@@ -18,12 +18,16 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
var (
ErrClientAlreadyStarted = errors.New("client already started")
ErrClientNotStarted = errors.New("client not started")
ErrEngineNotStarted = errors.New("engine not started")
ErrConfigNotInitialized = errors.New("config not initialized")
)
// Client manages a netbird embedded client instance.
type Client struct {
@@ -170,15 +174,14 @@ func (c *Client) Start(startCtx context.Context) error {
recorder := peer.NewRecorder(c.config.ManagementURL.String())
//todo: do we need to pass logFile here ?
client := internal.NewConnectClient(ctx, c.config, recorder, "")
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
@@ -240,17 +243,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
engine, err := c.getEngine()
if err != nil {
return nil, err
}
nsnet, err := engine.GetNet()
@@ -261,6 +256,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// DialContext dials a network address in the netbird network with context
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c.Dial(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
@@ -316,18 +316,47 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
engine, err := c.getEngine()
if err != nil {
return err
}
storedKey, found := engine.GetPeerSSHKey(peerAddress)
if !found {
return sshcommon.ErrPeerNotFound
}
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
func (c *Client) getEngine() (*internal.Engine, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
if connect == nil {
return nil, ErrClientNotStarted
}
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
return nil, ErrEngineNotStarted
}
return engine, nil
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {
return nil, netip.Addr{}, err
}
addr, err := engine.Address()

View File

@@ -15,13 +15,13 @@ import (
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}

View File

@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if !iface.IsUserspaceBind() {
return fm, err
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
return nbiptables.Create(iface, mtu)
case NFTABLES:
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
return nbnftables.Create(iface, mtu)
default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {

View File

@@ -1,18 +1,19 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
@@ -40,19 +41,13 @@ type aclManager struct {
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
}}, nil
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
// don't use IP matching if IP is 0.0.0.0
if ip.IsUnspecified() {
matchByIP = false
}
@@ -400,7 +419,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
@@ -417,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -36,7 +36,7 @@ type iFaceMapper interface {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("init iptables: %w", err)
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface)
m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -260,6 +261,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err)
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)

View File

@@ -10,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// constants needed to manage and create iptable rules
@@ -30,17 +30,20 @@ const (
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
@@ -48,6 +51,9 @@ const (
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
mtu uint16
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
@@ -99,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil
}
@@ -224,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
@@ -238,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
@@ -392,6 +396,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -416,6 +421,7 @@ func (r *router) createContainers() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
@@ -438,6 +444,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
@@ -518,6 +528,35 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -558,7 +597,7 @@ func (r *router) addJumpRules() error {
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists {
var table, chain string
switch ruleKey {
@@ -571,6 +610,9 @@ func (r *router) cleanJumpRules() error {
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}
@@ -869,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
@@ -880,6 +922,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
@@ -899,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -14,7 +14,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
func isIptablesSupported() bool {
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
// Now 5 rules:
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 7. static return masquerade rule
// 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
// 10. jump rule to MSS clamping chain
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -11,6 +12,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
ipt, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}

View File

@@ -100,6 +100,9 @@ type Manager interface {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
ip net.IP,
@@ -151,14 +154,20 @@ type Manager interface {
DisableRouting() error
// AddDNATRule adds a DNAT rule
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
// DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
@@ -29,8 +29,6 @@ const (
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
const flushError = "flush: %w"
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
},
)
return m.rConn.AddRule(&nftables.Rule{
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
return nil
}
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}

View File

@@ -1,11 +1,11 @@
package nftables
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
"github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
)
const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
// envTableName is the environment variable to override the table name
envTableName = "NB_NFTABLES_TABLE"
tableNameFilter = "filter"
chainNameInput = "INPUT"
)
func getTableName() string {
if name := os.Getenv(envTableName); name != "" {
return name
}
return tableNameNetbird
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
@@ -44,16 +53,16 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
return nil
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
@@ -376,61 +315,40 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))

View File

@@ -16,13 +16,14 @@ import (
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
@@ -32,12 +33,17 @@ const (
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -63,9 +69,10 @@ type router struct {
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
return nil
}
// Reset cleans existing nftables default forward rules from the system
// Reset cleans existing nftables filter table rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
return r.acceptFilterRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
return r.acceptFilterRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
}
}
func (r *router) acceptForwardRulesNftables() error {
func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
func (r *router) removeAcceptFilterRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
return r.removeAcceptFilterRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
@@ -1350,6 +1490,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
)
const (
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))

View File

@@ -3,6 +3,7 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -10,6 +11,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
nft, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
}
// these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists {
t.updateState(key, conn, flags, direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
return
}
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
}
}
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{
SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
if exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort,
DestPort: dstPort,
}
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net"
@@ -27,7 +28,18 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const layerTypeAll = 0
const (
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@@ -36,6 +48,9 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
@@ -109,6 +124,17 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
}
// decoder for packages
@@ -122,19 +148,21 @@ type decoder struct {
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger)
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
@@ -142,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}
func parseCreateEnv() (bool, bool) {
var disableConntrack, enableLocalForwarding bool
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
@@ -162,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
}
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
disableMSSClamping, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
}
}
return disableConntrack, enableLocalForwarding
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
decoders: sync.Pool{
@@ -196,13 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
mtu: mtu,
}
m.routingEnabled.Store(false)
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
}
if disableConntrack {
log.Info("conntrack is disabled")
} else {
@@ -210,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
@@ -320,7 +357,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
if err != nil {
m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err)
@@ -626,11 +663,20 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return false
}
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
switch d.decoded[1] {
case layers.LayerTypeUDP:
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
}
case layers.LayerTypeTCP:
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
m.clampTCPMSS(packetData, d)
}
}
m.trackOutbound(d, srcIP, dstIP, size)
m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d)
return false
@@ -674,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
if !d.tcp.SYN {
return false
}
if len(d.tcp.Options) == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
}
}
if mssOptionIndex == -1 {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
optOffset := tcpOptionsStart
for j := 0; j < mssOptionIndex; j++ {
switch d.tcp.Options[j].OptionType {
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
optOffset++
default:
optOffset += 2 + len(d.tcp.Options[j].OptionData)
}
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
}
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
tcpLayer[16] = 0
tcpLayer[17] = 0
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
if tcpLength%2 == 1 {
sum += uint32(tcpLayer[tcpLength-1]) << 8
}
for sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16)
}
checksum := ^uint16(sum)
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
}
@@ -691,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
}
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
@@ -759,10 +910,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
@@ -807,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true
}
// If requested we pass local traffic to internal interfaces to the forwarder.
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
if m.shouldForward(d, dstIP) {
return m.handleForwardedLocalTraffic(packetData)
}
@@ -1243,3 +1402,86 @@ func (m *Manager) DisableRouting() error {
return nil
}
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
m.netstackServices[key] = struct{}{}
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
}
// UnregisterNetstackService removes a service from the netstack registry
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
delete(m.netstackServices, key)
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
}
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
switch protocol {
case nftypes.TCP:
return layers.LayerTypeTCP
case nftypes.UDP:
return layers.LayerTypeUDP
case nftypes.ICMP:
return layers.LayerTypeICMPv4
default:
return gopacket.LayerType(0) // Invalid/unknown
}
}
// shouldForward determines if a packet should be forwarded to the forwarder.
// The forwarder handles routing packets to the native OS network stack.
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
// not enabled, never forward
if !m.localForwarding {
return false
}
// netstack always needs to forward because it's lacking a native interface
// exception for registered netstack services, those should go to netstack listeners
if m.netstack {
return !m.hasMatchingNetstackService(d)
}
// traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP {
return true
}
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
return false
}
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
if len(d.decoded) < 2 {
return false
}
var dstPort uint16
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
dstPort = uint16(d.udp.DstPort)
default:
return false
}
key := serviceKey{protocol: d.decoded[1], port: dstPort}
m.netstackServiceMutex.RLock()
_, exists := m.netstackServices[key]
m.netstackServiceMutex.RUnlock()
return exists
}

View File

@@ -17,6 +17,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
}
}
}
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
// This shows the overhead difference between the common case (non-SYN packets, fast path)
// and the rare case (SYN packets that need clamping, expensive path).
func BenchmarkMSSClamping(b *testing.B) {
scenarios := []struct {
name string
description string
genPacket func(*testing.B, net.IP, net.IP) []byte
frequency string
}{
{
name: "syn_needs_clamp",
description: "SYN packet needing MSS clamping",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
frequency: "~0.1% of traffic - EXPENSIVE",
},
{
name: "syn_no_clamp_needed",
description: "SYN packet with already-small MSS",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
},
frequency: "~0.05% of traffic",
},
{
name: "tcp_ack",
description: "Non-SYN TCP packet (ACK, data transfer)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
frequency: "~60-70% of traffic - FAST PATH",
},
{
name: "tcp_psh_ack",
description: "TCP data packet (PSH+ACK)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
frequency: "~10-20% of traffic - FAST PATH",
},
{
name: "udp",
description: "UDP packet",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
frequency: "~20-30% of traffic - FAST PATH",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
// for the common case (non-SYN TCP packets).
func BenchmarkMSSClampingOverhead(b *testing.B) {
scenarios := []struct {
name string
enabled bool
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "disabled_tcp_ack",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "enabled_tcp_ack",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "disabled_syn_needs_clamp",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "enabled_syn_needs_clamp",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
func BenchmarkMSSClampingMemory(b *testing.B) {
scenarios := []struct {
name string
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "tcp_ack_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "syn_needs_clamp",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "udp_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
b.Helper()
ip := &layers.IPv4{
Version: 4,
IHL: 5,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Seq: 1000,
Window: 65535,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
return buf.Bytes()
}

View File

@@ -12,6 +12,7 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -17,9 +18,11 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -66,7 +69,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -86,7 +89,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -119,7 +122,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -215,7 +218,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -265,7 +268,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -304,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -367,7 +370,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false, flowLogger)
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -413,7 +416,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -495,7 +498,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -522,7 +525,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -729,7 +732,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -815,7 +818,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -923,3 +926,327 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}
func TestMSSClamping(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
srcIP := net.ParseIP("100.10.0.2")
dstIP := net.ParseIP("8.8.8.8")
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
lowMSS := uint16(1200)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
})
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
})
}
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
ACK: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Window: 65535,
}
if flags&uint16(conntrack.TCPSyn) != 0 {
tcpLayer.SYN = true
}
if flags&uint16(conntrack.TCPAck) != 0 {
tcpLayer.ACK = true
}
if flags&uint16(conntrack.TCPFin) != 0 {
tcpLayer.FIN = true
}
if flags&uint16(conntrack.TCPRst) != 0 {
tcpLayer.RST = true
}
if flags&uint16(conntrack.TCPPush) != 0 {
tcpLayer.PSH = true
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func TestShouldForward(t *testing.T) {
// Set up test addresses
wgIP := netip.MustParseAddr("100.10.0.1")
otherIP := netip.MustParseAddr("100.10.0.2")
// Create test manager with mock interface
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// Set the mock to return our test WG IP
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Helper to create decoder with TCP packet
createTCPDecoder := func(dstPort uint16) *decoder {
ipv4 := &layers.IPv4{
Version: 4,
Protocol: layers.IPProtocolTCP,
SrcIP: net.ParseIP("192.168.1.100"),
DstIP: wgIP.AsSlice(),
}
tcp := &layers.TCP{
SrcPort: 54321,
DstPort: layers.TCPPort(dstPort),
}
err := tcp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
require.NoError(t, err)
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
require.NoError(t, err)
return d
}
tests := []struct {
name string
localForwarding bool
netstack bool
dstIP netip.Addr
serviceRegistered bool
servicePort uint16
expected bool
description string
}{
{
name: "no local forwarding",
localForwarding: false,
netstack: true,
dstIP: wgIP,
expected: false,
description: "should never forward when local forwarding disabled",
},
{
name: "traffic to other local interface",
localForwarding: true,
netstack: false,
dstIP: otherIP,
expected: true,
description: "should forward traffic to our other local interfaces (not NetBird IP)",
},
{
name: "traffic to NetBird IP, no netstack",
localForwarding: true,
netstack: false,
dstIP: wgIP,
expected: false,
description: "should send to netstack listeners (final return false path)",
},
{
name: "traffic to our IP, netstack mode, no service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
expected: true,
description: "should forward when in netstack mode with no matching service",
},
{
name: "traffic to our IP, netstack mode, with service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
serviceRegistered: true,
servicePort: 22,
expected: false,
description: "should send to netstack listeners when service is registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager
manager.localForwarding = tt.localForwarding
manager.netstack = tt.netstack
// Register service if needed
if tt.serviceRegistered {
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
}
// Create decoder for the test
decoder := createTCPDecoder(tt.servicePort)
if !tt.serviceRegistered {
decoder = createTCPDecoder(8080) // Use non-registered port
}
// Test the method
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -45,7 +45,7 @@ type Forwarder struct {
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
return nil, fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{

View File

@@ -49,7 +49,7 @@ type idleConn struct {
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,

View File

@@ -50,6 +50,8 @@ type logMessage struct {
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
}
}
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
}
}
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
}
*buf = append(*buf, formatted...)
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done:
return nil
}
}
}

View File

@@ -5,7 +5,9 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +15,21 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum)
}
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum)
}
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
}
}
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
}
}
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
if !originalAddr.IsValid() {
return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
// getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
// findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
// translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err)
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
// translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err)
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil
}
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
// incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
// DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
}
})
}
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -0,0 +1,85 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestPortDNATBasic tests basic port DNAT functionality
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rule for peer B (the target)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d := parsePacket(t, packetAtoB)
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
// Verify port was translated to 22022
d = parsePacket(t, packetAtoB)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
// (prevents double NAT - original port stored in conntrack)
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
d2 := parsePacket(t, returnPacket)
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
}
// TestPortDNATMultipleRules tests multiple port DNAT rules
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rules for both peers
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Test traffic to peer B gets translated
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d1 := parsePacket(t, packetToB)
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
require.True(t, translatedToB, "Traffic to peer B should be translated")
d1 = parsePacket(t, packetToB)
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
// Test traffic to peer A gets translated
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
d2 := parsePacket(t, packetToA)
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
require.True(t, translatedToA, "Traffic to peer A should be translated")
d2 = parsePacket(t, packetToA)
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -15,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -99,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -143,3 +145,111 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const (
StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageReceived: "Received",
StageInboundPortDNAT: "Inbound Port DNAT",
StageInbound1to1NAT: "Inbound 1:1 NAT",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s]
}
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
}
return trace
}
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -10,6 +10,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {
@@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -246,6 +259,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageCompleted,
@@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageCompleted,
},
@@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: true,
@@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting,
StagePeerACL,
StageCompleted,

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"runtime"
"time"
@@ -29,7 +30,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -37,13 +39,11 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
RootCAs: certPool,
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
@@ -58,8 +58,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
return nil, fmt.Errorf("dial context: %w", err)
}
return conn, nil

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil

View File

@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,5 +1,17 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -0,0 +1,9 @@
package bufsize
const (
// WGBufferOverhead represents the additional buffer space needed beyond MTU
// for WireGuard packet encapsulation (WG header + UDP + IP + safety margin)
// Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220
// TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference
WGBufferOverhead = 220
)

View File

@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -394,6 +455,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil
}
@@ -402,7 +470,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark
}
return 0

View File

@@ -7,19 +7,21 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -21,7 +22,7 @@ type WGTunDevice struct {
address wgaddr.Address
port int
key string
mtu int
mtu uint16
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
@@ -29,11 +30,11 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -58,7 +59,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return nil, err
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -137,6 +138,10 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *WGTunDevice) MTU() uint16 {
return t.mtu
}
func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
@@ -145,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -21,16 +22,16 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu int
mtu uint16
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -42,7 +43,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
tunDevice, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -111,6 +112,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string {
return t.name
}
@@ -149,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -22,21 +23,23 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
tunFd int
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunFd: tunFd,
}
@@ -81,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -125,6 +128,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
return nil
@@ -137,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -24,19 +24,19 @@ type TunKernelDevice struct {
address wgaddr.Address
wgPort int
key string
mtu int
mtu uint16
ctx context.Context
ctxCancel context.CancelFunc
transportNet transport.Net
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
// TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(t.mtu); err != nil {
if err := link.setMTU(int(t.mtu)); err != nil {
return nil, fmt.Errorf("set mtu: %w", err)
}
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -96,23 +96,19 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter())
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu)
if err != nil {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux
@@ -158,6 +154,10 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunKernelDevice) MTU() uint16 {
return t.mtu
}
func (t *TunKernelDevice) DeviceName() string {
return t.name
}
@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -1,6 +1,3 @@
//go:build !android
// +build !android
package device
import (
@@ -15,14 +12,16 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct {
@@ -30,14 +29,14 @@ type TunNetstackDevice struct {
address wgaddr.Address
port int
key string
mtu int
mtu uint16
listenAddress string
bind Bind
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
@@ -55,7 +54,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
}
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
func (t *TunNetstackDevice) create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
@@ -65,7 +64,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
}
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil {
@@ -91,7 +90,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -137,6 +136,10 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunNetstackDevice) MTU() uint16 {
return t.mtu
}
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
@@ -153,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -0,0 +1,7 @@
//go:build android
package device
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}

View File

@@ -0,0 +1,7 @@
//go:build !android
package device
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
return t.create()
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -20,16 +21,16 @@ type USPDevice struct {
address wgaddr.Address
port int
key string
mtu int
mtu uint16
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -44,9 +45,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -118,6 +119,10 @@ func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *USPDevice) MTU() uint16 {
return t.mtu
}
func (t *USPDevice) DeviceName() string {
return t.name
}
@@ -141,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -23,17 +24,17 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu int
mtu uint16
iceBind *bind.ICEBind
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -59,7 +60,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, err
}
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu))
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -144,6 +145,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string {
return t.name
}
@@ -180,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -0,0 +1,13 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,4 +21,5 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
}

View File

@@ -5,19 +5,21 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -26,6 +26,8 @@ import (
const (
DefaultMTU = 1280
MinMTU = 576
MaxMTU = 8192
DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault
)
@@ -35,6 +37,17 @@ var (
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
// ValidateMTU validates that MTU is within acceptable range
func ValidateMTU(mtu uint16) error {
if mtu < MinMTU {
return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU)
}
if mtu > MaxMTU {
return fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU)
}
return nil
}
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -45,10 +58,10 @@ type WGIFaceOpts struct {
Address string
WGPort int
WGPrivKey string
MTU int
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
FilterFn udpmux.FilterFn
DisableDNS bool
}
@@ -67,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
@@ -82,6 +106,10 @@ func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress()
}
func (w *WGIface) MTU() uint16 {
return w.tun.MTU()
}
// ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName()
@@ -97,7 +125,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -131,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()

View File

@@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -16,10 +16,10 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build linux && !android
package iface
@@ -22,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
@@ -31,11 +31,11 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)

View File

@@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -1,6 +1,7 @@
package iface
import (
"context"
"fmt"
"net"
"net/netip"
@@ -9,13 +10,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// keep darwin compatibility
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
newNet, err = stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
Mux *SingleSocketUDPMux
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
CandidateID string
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err
}
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,6 +1,7 @@
package bind
package udpmux
import (
"context"
"fmt"
"io"
"net"
@@ -8,12 +9,13 @@ import (
"strings"
"sync"
"github.com/pion/ice/v3"
"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
/*
@@ -22,9 +24,9 @@ import (
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
// SingleSocketUDPMux is an implementation of the interface
type SingleSocketUDPMux struct {
params Params
closedChan chan struct{}
closeOnce sync.Once
@@ -32,6 +34,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -46,8 +51,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
// Params are parameters for UDPMux.
type Params struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -147,18 +152,19 @@ func isZeros(ip net.IP) bool {
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
// NewSingleSocketUDPMux creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil {
params.Logger = getLogger()
}
mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
@@ -171,15 +177,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -195,7 +201,7 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
@@ -216,13 +222,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -236,7 +242,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +266,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil
}
c := m.createMuxedConn(ufrag)
c := m.createMuxedConn(ufrag, candidateID)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.candidateConnMap[candidateID] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -276,7 +284,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +292,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -314,7 +324,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
func (m *SingleSocketUDPMux) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -324,7 +334,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
func (m *SingleSocketUDPMux) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -347,11 +357,11 @@ func (m *UDPMuxDefault) Close() error {
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -368,81 +378,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
CandidateID: candidateID,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
// Try to route to specific candidate connection first
if conn := m.findCandidateConnection(msg); conn != nil {
return conn.writePacket(msg.Raw, remoteAddr)
}
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
// Forward to all found connections
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
// findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -451,6 +489,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return
}
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct {
buf []byte
}

View File

@@ -1,12 +1,12 @@
//go:build !ios
package bind
package udpmux
import (
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,9 +15,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -28,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
*SingleSocketUDPMux
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -44,6 +45,7 @@ type UniversalUDPMuxParams struct {
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
MTU uint16
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -70,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
udpMuxParams := UDPMuxParams{
udpMuxParams := Params{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m
}
@@ -84,7 +86,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, 1500)
buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead)
for {
select {
case <-ctx.Done():
@@ -209,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -231,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
@@ -33,9 +34,9 @@ type ProxyBind struct {
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
mtu uint16
@@ -43,7 +44,7 @@ type ProxyBind struct {
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
p := &ProxyBind{
Bind: bind,
bind: bind,
closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
@@ -55,25 +56,25 @@ func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
if err != nil {
return err
}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -85,17 +86,21 @@ func (p *ProxyBind) Work() {
return
}
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) Pause() {
@@ -103,9 +108,25 @@ func (p *ProxyBind) Pause() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
@@ -116,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -129,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -156,10 +186,9 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])

View File

@@ -6,9 +6,7 @@ import (
"context"
"fmt"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -17,18 +15,25 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
mtu uint16
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
@@ -43,10 +48,11 @@ type WGEBPFProxy struct {
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
mtu: mtu,
ebpfManager: ebpf.GetEbpfManagerInstance(),
turnConnStore: make(map[uint16]net.Conn),
}
@@ -61,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err
}
p.rawConn, err = p.prepareSenderRawSocket()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
return err
}
@@ -138,7 +144,7 @@ func (p *WGEBPFProxy) Free() error {
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil {
@@ -211,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localhost,
SrcIP: localhost,
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(port),
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
@@ -276,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil

View File

@@ -7,7 +7,7 @@ import (
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
wgProxy := NewWGEBPFProxy(1, 1280)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
wgProxy := NewWGEBPFProxy(1, 1280)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1)
wgProxy := NewWGEBPFProxy(1, 1280)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)

View File

@@ -12,46 +12,48 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
p.wgRelayedEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
return p.wgRelayedEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -63,14 +65,18 @@ func (p *ProxyWrapper) Work() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) Pause() {
@@ -79,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
p.cancel()
e.closeListener.SetCloseListener(nil)
p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("close remote conn: %w", err)
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, 1500)
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
if err != nil {
if ctx.Err() != nil {
@@ -136,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
}
return 0, err
}

View File

@@ -11,16 +11,18 @@ import (
type KernelFactory struct {
wgPort int
mtu uint16
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int) *KernelFactory {
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
@@ -33,11 +35,10 @@ func NewKernelFactory(wgPort int) *KernelFactory {
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort)
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -1,29 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

Some files were not shown because too many files have changed in this diff Show More