Compare commits

..

266 Commits

Author SHA1 Message Date
Viktor Liu
7538787fb1 Merge remote-tracking branch 'origin/client-ipv6-iptables' into client-ipv6-android-ui 2026-04-07 18:45:22 +02:00
Viktor Liu
ce5852b9dd Merge remote-tracking branch 'origin/client-ipv6-nftables' into client-ipv6-iptables 2026-04-07 18:45:13 +02:00
Viktor Liu
935ca81a8a Merge remote-tracking branch 'origin/client-ipv6-routing' into client-ipv6-nftables 2026-04-07 18:45:05 +02:00
Viktor Liu
ff1838b651 Merge remote-tracking branch 'origin/client-ipv6-acl-usp' into client-ipv6-routing 2026-04-07 18:44:57 +02:00
Viktor Liu
01068dd03e Merge remote-tracking branch 'origin/proto-ipv6-overlay' into client-ipv6-acl-usp 2026-04-07 18:44:49 +02:00
Viktor Liu
90c5065c66 Add missing SetInterfaceIPv6 to Android noop network listener 2026-04-07 18:44:27 +02:00
Viktor Liu
7dff0e2f70 Keep successful DNS answers when only one address family query fails 2026-04-07 18:40:09 +02:00
Viktor Liu
07c7d2698a Merge remote-tracking branch 'origin/client-ipv6-iptables' into client-ipv6-android-ui 2026-04-07 18:37:12 +02:00
Viktor Liu
70530081f2 Merge remote-tracking branch 'origin/client-ipv6-nftables' into client-ipv6-iptables 2026-04-07 18:36:53 +02:00
Viktor Liu
ea451dfdc6 Merge remote-tracking branch 'origin/client-ipv6-routing' into client-ipv6-nftables 2026-04-07 18:36:22 +02:00
Viktor Liu
a3925c9a89 Merge remote-tracking branch 'origin/client-ipv6-acl-usp' into client-ipv6-routing 2026-04-07 18:36:00 +02:00
Viktor Liu
974ea1fb09 Merge remote-tracking branch 'origin/proto-ipv6-overlay' into client-ipv6-acl-usp 2026-04-07 18:35:38 +02:00
Viktor Liu
9592de1aac Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay
# Conflicts:
#	client/android/client.go
#	client/ssh/server/server.go
#	shared/management/proto/management.pb.go
2026-04-07 18:35:13 +02:00
Viktor Liu
a5d4df009a Merge pull request #5675 from netbirdio/client-ipv6-iface
[client] Add IPv6 overlay address support to WireGuard interface and engine
2026-04-08 00:30:54 +08:00
Viktor Liu
ed646f5485 Merge pull request #5686 from netbirdio/client-ipv6-dns
[client] Add IPv6 reverse DNS and host configurator support
2026-04-08 00:30:47 +08:00
Viktor Liu
daeb90cf98 Merge pull request #5687 from netbirdio/client-ipv6-ssh-netflow
[client] Add IPv6 support to SSH server, client config, and netflow logger
2026-04-08 00:30:36 +08:00
Viktor Liu
024cc6cba9 Split Init to reduce cognitive complexity 2026-04-07 18:22:14 +02:00
Viktor Liu
546140a2c2 Extract rollbackInit helper to deduplicate Init cleanup 2026-04-07 18:18:15 +02:00
Viktor Liu
5ccb00df2c Merge branch 'client-ipv6-nftables' into client-ipv6-iptables 2026-04-07 18:16:58 +02:00
Viktor Liu
076bb3292e Add best-effort rollback to iptables Init on partial failure 2026-04-07 18:16:50 +02:00
Viktor Liu
443d0720a3 Fix CodeRabbit findings: rollback on init failure, accumulate Close errors, expand test guards 2026-04-07 18:15:18 +02:00
Viktor Liu
1d920d700c [client] Fix SSH server Stop() deadlock when sessions are active (#5717) 2026-04-07 17:56:54 +02:00
Viktor Liu
bb85eee40a [client] Skip down interfaces in network address collection for posture checks (#5768) 2026-04-07 17:56:48 +02:00
Viktor Liu
aba5d6f0d2 [client] Error out on netbird expose when block inbound is enabled (#5818) 2026-04-07 17:55:35 +02:00
Viktor Liu
0588d2dbe1 [management] Load missing service columns in pgx account loader (#5816) 2026-04-07 14:56:56 +02:00
Pascal Fischer
14b3b77bda [management] validate permissions on groups read with name (#5749) 2026-04-07 14:13:09 +02:00
Zoltan Papp
6da34e483c [client] Fix mgmProber interface to match unexported GetServerPublicKey (#5815)
Update the mgmProber interface to use HealthCheck() instead of the
now-unexported GetServerPublicKey(), aligning with the changes in the
management client API.
2026-04-07 13:13:38 +02:00
Zoltan Papp
0efef671d7 [client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
2026-04-07 12:18:21 +02:00
Eduard Gert
435203b13b [proxy] Update proxy web packages (#5661)
* [proxy] Update package-lock.json

* Update packages
2026-04-07 10:35:09 +02:00
Maycon Santos
decb5dd3af [client] Add GetSelectedClientRoutes to route manager and update DNS route check (#5802)
- DNS resolution broke after deselecting an exit node because the route checker used all client routes (including deselected ones) to decide how to forward upstream DNS
  queries
  - Added GetSelectedClientRoutes() to the route manager that filters out deselected exit nodes, and switched the DNS route checker to use it
  - Confirmed fix via device testing: after deselecting exit node, DNS queries now correctly use a regular network socket instead of binding to the utun interface
2026-04-05 13:44:53 +02:00
Viktor Liu
28fbf96b2a [client] Fix flaky TestServiceLifecycle/Restart on FreeBSD (#5786) 2026-04-02 21:45:49 +02:00
Bethuel Mmbaga
9d1a37c644 [management,client] Revert gRPC client secret removal (#5781)
* This reverts commit e5914e4e8b

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Deprecate client secret in proto

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-04-02 18:21:00 +02:00
Viktor Liu
5bf2372c4d [management] Fix L4 service creation deadlock on single-connection databases (#5779) 2026-04-02 14:46:14 +02:00
Bethuel Mmbaga
c2c6396a04 [management] Allow updating embedded IdP user name and email (#5721) 2026-04-02 13:02:10 +03:00
Misha Bragin
aaf813fc0c Add selfhosted scaling note (#5769) 2026-04-01 19:23:39 +02:00
Vlad
d97fe84296 [management] fix race condition in the setup flow that enables creation of multiple owner users (#5754) 2026-04-01 16:25:35 +02:00
tham-le
81f45dab21 [client] Support embed.Client on Android with netstack mode (#5623)
* [client] Support embed.Client on Android with netstack mode

embed.Client.Start() calls ConnectClient.Run() which passes an empty
MobileDependency{}. On Android, the engine dereferences nil fields
(IFaceDiscover, NetworkChangeListener, DnsReadyListener) causing panics.

Provide complete no-op stubs so the engine's existing Android code
paths work unchanged — zero modifications to engine.go:

- Add androidRunOverride hook in Run() for Android-specific dispatch
- Add runOnAndroidEmbed() with complete MobileDependency (all stubs)
- Wire default stubs via init() in connect_android_default.go:
  noopIFaceDiscover, noopNetworkChangeListener, noopDnsReadyListener
- Forward logPath to c.run()

Tested: embed.Client starts on Android arm64, joins mesh via relay,
discovers peers, localhost proxy works for TCP+UDP forwarding.

* [client] Fix TestServiceParamsPath for Windows path separators

Use filepath.Join in test assertions instead of hardcoded POSIX paths
so the test passes on Windows where filepath.Join uses backslashes.
2026-04-01 16:19:34 +02:00
Zoltan Papp
d670e7382a [client] Fix ipv6 address in quic server (#5763)
* [client] Use `net.JoinHostPort` for consistency in constructing host-port pairs

* [client] Fix handling of IPv6 addresses by trimming brackets in `net.JoinHostPort`
2026-04-01 15:11:23 +02:00
Pascal Fischer
cd8c686339 [misc] add path traversal and file size protections (#5755) 2026-04-01 14:23:24 +02:00
Pascal Fischer
f5c41e3018 [misc] set permissions on env file for getting started scripts (#5761) 2026-04-01 14:13:53 +02:00
Pascal Fischer
2477f99d89 [proxy] Add pprof (#5764) 2026-04-01 14:10:41 +02:00
shuuri-labs
940f530ac2 [management] Legacy to embedded IdP migration tool (#5586) 2026-04-01 13:53:19 +02:00
Zoltan Papp
4d3e2f8ad3 Fix path join (#5762) 2026-04-01 13:21:19 +02:00
Vlad
5ae986e1c4 [management] fix panic on management reboot (#5759) 2026-04-01 12:31:30 +02:00
Bethuel Mmbaga
e5914e4e8b [management,client] Remove client secret from gRPC auth flow (#5751)
Remove client secret from gRPC auth flow. The secret was originally included to support providers like Google Workspace that don't offer a proper PKCE flow, but this is no longer necessary with the embedded IdP. Deployments using such providers should migrate to the embedded IdP instead.
2026-03-31 18:50:49 +03:00
Pascal Fischer
c238f5425f [management] proper module permission validation for posture check delete (#5742) 2026-03-31 16:43:49 +02:00
Pascal Fischer
3c3097ea74 [management] add target user account validation (#5741) 2026-03-31 16:43:16 +02:00
Maycon Santos
405c3f4003 [management] Feature/fleetdm api spec (#5597)
add fleetdm api spec
2026-03-31 14:03:34 +02:00
Viktor Liu
6553ce4cea [client] Mock management client in TestUpdateOldManagementURL to fix CI flakiness (#5703) 2026-03-31 10:49:06 +02:00
Viktor Liu
a62d472bc4 [client] Include fake IP block routes in Android TUN rebuilds (#5739) 2026-03-31 10:36:27 +02:00
Eduard Gert
434ac7f0f5 [docs] Update CONTRIBUTOR_LICENSE_AGREEMENT.md (#5131) 2026-03-31 09:31:03 +02:00
Viktor Liu
a457faa876 Extract v6 exit node pairing into shared helpers 2026-03-30 19:31:43 +02:00
Viktor Liu
95f4e90ae8 Add IPv6 dual-stack support for Android, iOS, and desktop UI 2026-03-30 18:53:04 +02:00
Akshay Ubale
7bbe71c3ac [client] Refactor Android PeerInfo to use proper ConnStatus enum type (#5644)
* Simplify Android ConnStatus API with integer constants

Replace dual field PeerInfo design with unified integer based
ConnStatus field and exported gomobile friendly constants.

Changes:
> PeerInfo.ConnStatus: changed from string to int
> Export three constants: ConnStatusIdle, ConnStatusConnecting,ConnStatusConnected (mapped to peer.ConnStatus enum values)
> Updated PeersList() to convert peer enum directly to int

Benefits:
> Simpler API surface with single ConnStatus field
> Better gomobile compatibility for cross-platform usage
> Type-safe integer constants across language boundaries

* test: add All group to setupTestAccount fixture

The setupTestAccount() test helper was missing the required "All" group,
causing "failed to get group all: no group ALL found" errors during
test execution. Add the All group with all test peers to match the
expected account structure.

Fixes the failing account and types package tests when GetGroupAll()
is called in test scenarios.
2026-03-30 17:55:01 +02:00
Viktor Liu
21ae32bb52 Add IPv6 address parameter to TunAdapter.ConfigureInterface 2026-03-30 17:43:52 +02:00
Viktor Liu
04dcaadabf [client] Persist service install parameters across reinstalls (#5732) 2026-03-30 16:25:14 +02:00
Zoltan Papp
c522506849 [client] Add Expose support to embed library (#5695)
* [client] Add Expose support to embed library

Add ability to expose local services via the NetBird reverse proxy
from embedded client code.

Introduce ExposeSession with a blocking Wait method that keeps
the session alive until the context is cancelled.

Extract ProtocolType with ParseProtocolType into the expose package
and use it across CLI and embed layers.

* Fix TestNewRequest assertion to use ProtocolType instead of int

* Add documentation for Request and KeepAlive in expose manager

* Refactor ExposeSession to pass context explicitly in Wait method

* Refactor ExposeSession Wait method to explicitly pass context

* Update client/embed/expose.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Fix build

* Update client/embed/expose.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2026-03-30 15:53:50 +02:00
Viktor Liu
0765352c99 [management] Persist proxy capabilities to database (#5720) 2026-03-30 13:03:42 +02:00
tobsec
13807f1b3d [client] Fix Exit Node submenu separator accumulation on Windows (#5691)
* client/ui: fix Exit Node submenu separator accumulation on Windows

On Windows the tray uses a background poller (every 10s) instead of
TrayOpenedCh to keep the Exit Node menu fresh. Each poll that has a
selected exit node called s.mExitNode.AddSeparator() before the
"Deselect All" item. Because AddSeparator() returns no handle the
separator was never removed in the cleanup pass of
recreateExitNodeMenu(), while every other item (exit node checkboxes
and the "Deselect All" entry) was properly tracked and removed.

After the client has been running for a while with an exit node
selected this leaves hundreds of separator lines stacked in the
submenu, filling the screen height with blank entries (#4702).

On Linux/FreeBSD this is masked because the parent mExitNode item
itself is removed and recreated each cycle, wiping all children
including orphaned separators.

Fix: replace the untracked AddSeparator() call with a regular disabled
sub-menu item that is stored in mExitNodeSeparator and removed at the
start of each recreateExitNodeMenu() call alongside mExitNodeDeselectAll.

Fixes #4702

* client/ui: extract addExitNodeDeselectAll to reduce cognitive complexity

Move the separator + deselect-all creation and its goroutine listener
out of recreateExitNodeMenu into a dedicated helper, bringing the
function's cognitive complexity back under the SonarCloud threshold.
2026-03-30 10:41:38 +02:00
Bethuel Mmbaga
c919ea149e [misc] Add missing OpenAPI definitions (#5690) 2026-03-30 11:20:17 +03:00
Pascal Fischer
be6fd119d8 [management] no events for temporary peers (#5719) 2026-03-30 10:08:02 +02:00
Viktor Liu
310b7dd89b Merge branch 'client-ipv6-nftables' into client-ipv6-iptables 2026-03-27 17:01:21 +01:00
Viktor Liu
44d16e8791 Merge branch 'client-ipv6-routing' into client-ipv6-nftables 2026-03-27 17:01:13 +01:00
Viktor Liu
a9ee25aeef Merge branch 'client-ipv6-acl-usp' into client-ipv6-routing
# Conflicts:
#	client/firewall/uspfilter/forwarder/forwarder.go
#	client/firewall/uspfilter/forwarder/icmp.go
2026-03-27 17:00:57 +01:00
Viktor Liu
ed5cfa6dc5 Fix CodeRabbit findings: fragment guard, v6 raw socket probe, v6 echo logging 2026-03-27 15:45:19 +01:00
Viktor Liu
fcf8c4b30e Handle ICMP directly in forwarder, bypassing gVisor network layer 2026-03-27 15:25:01 +01:00
Pascal Fischer
7abf730d77 [management] update to latest grpc version (#5716) 2026-03-27 15:22:23 +01:00
Viktor Liu
faef5a57b3 Add -4/-6 IP version flags to proxy debug ping, WASM ping, and SSH 2026-03-27 13:11:34 +01:00
Viktor Liu
22c4be0c34 Fix CodeRabbit findings: anonymizer test, blockLanAccess v6 source, Windows cleanup accumulation 2026-03-27 06:57:37 +01:00
Viktor Liu
b6bd2d667b Add dual-stack iptables manager with ip6tables support 2026-03-27 06:57:37 +01:00
Viktor Liu
571527c2d3 Add iptablesProto helper, filter table fallback, DNAT compat tests 2026-03-27 06:56:15 +01:00
Viktor Liu
c13f1af196 Fix ip6tables-save compat: skip ip6tables-managed tables in external chain scan 2026-03-27 06:50:05 +01:00
Pascal Fischer
ec96c5ecaf [management] Extend blackbox tests (#5699) 2026-03-26 16:59:49 +01:00
Pascal Fischer
7e1cce4b9f [management] add terminated field to service (#5700) 2026-03-26 16:59:08 +01:00
Bethuel Mmbaga
7be8752a00 [management] Add notification endpoints (#5590) 2026-03-26 18:26:33 +03:00
Viktor Liu
e4857b4d9d Add dual-stack nftables manager with IPv6 table support 2026-03-26 14:20:43 +01:00
Viktor Liu
da6f61039a Add IPv6 routing support with forwarding, fake IPs, and exit node handling 2026-03-26 14:00:10 +01:00
Viktor Liu
76414a1061 Merge branch 'client-ipv6-ssh-netflow' into client-ipv6-acl-usp 2026-03-26 13:00:20 +01:00
Viktor Liu
1cc19e7355 Merge branch 'client-ipv6-dns' into client-ipv6-ssh-netflow 2026-03-26 13:00:10 +01:00
Viktor Liu
1a6cf9dfec Merge branch 'client-ipv6-iface' into client-ipv6-dns
# Conflicts:
#	client/internal/dns/upstream_ios.go
2026-03-26 12:59:58 +01:00
Viktor Liu
6acc6a13f1 Merge branch 'proto-ipv6-overlay' into client-ipv6-iface 2026-03-26 12:56:48 +01:00
Viktor Liu
58eb519dbc Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay 2026-03-26 12:56:35 +01:00
Viktor Liu
145d82f322 [client] Replace iOS DNS IsPrivate heuristic with route manager check (#5694) 2026-03-26 18:11:05 +08:00
Viktor Liu
a8b9570700 [client] Enable RPM package signature verification in install script (#5676) 2026-03-26 09:50:43 +01:00
Viktor Liu
6ff6d84646 [client] Bump go-m1cpu to v0.2.1 to fix segfault on macOS 26 / M5 chips (#5701) 2026-03-26 09:49:02 +01:00
Viktor Liu
aebf3ceb40 Merge branch 'client-ipv6-ssh-netflow' into client-ipv6-acl-usp 2026-03-25 11:06:28 +01:00
Viktor Liu
bc6ed1a97f Merge branch 'client-ipv6-dns' into client-ipv6-ssh-netflow 2026-03-25 10:58:33 +01:00
Viktor Liu
50c0bc583b Fix connect.go lint: use SetIPv6FromCompact instead of if-else chain 2026-03-25 10:57:40 +01:00
Viktor Liu
5e1cdd7d36 Fix localip bitmap aliasing and bench test indentation 2026-03-25 10:30:04 +01:00
Viktor Liu
7fe417c6b4 Merge branch 'client-ipv6-ssh-netflow' into client-ipv6-acl-usp 2026-03-25 10:19:45 +01:00
Viktor Liu
0c2fbd5d70 Merge branch 'client-ipv6-dns' into client-ipv6-ssh-netflow 2026-03-25 10:19:42 +01:00
Viktor Liu
641e3861c1 Merge branch 'client-ipv6-iface' into client-ipv6-dns
# Conflicts:
#	client/iface/wgaddr/address.go
2026-03-25 10:18:52 +01:00
Viktor Liu
baf2c03508 Fix CodeRabbit findings: hasIPv6Changed restart loop, empty peerIPs panic, v6 validation 2026-03-25 10:18:06 +01:00
Viktor Liu
569244a59c Merge branch 'client-ipv6-ssh-netflow' into client-ipv6-acl-usp 2026-03-25 10:07:49 +01:00
Viktor Liu
5a29fa8432 Merge branch 'client-ipv6-dns' into client-ipv6-ssh-netflow 2026-03-25 10:07:46 +01:00
Viktor Liu
5fcea07181 Merge branch 'client-ipv6-iface' into client-ipv6-dns
# Conflicts:
#	client/iface/wgaddr/address.go
2026-03-25 10:07:43 +01:00
Viktor Liu
3be5a5f230 Fix CodeRabbit findings: hasIPv6Changed restart loop, empty peerIPs panic, v6 validation 2026-03-25 10:06:55 +01:00
Viktor Liu
780fd66dd7 Fix review findings: v6 block rule, PMTUD, eBPF typed IP, tracer validation 2026-03-25 09:58:28 +01:00
Viktor Liu
133f004086 Add IPv6 support to ACL manager, USP filter, and forwarder 2026-03-25 09:58:28 +01:00
Viktor Liu
d81cd5d154 Add IPv6 support to SSH server, client config, and netflow logger 2026-03-25 09:57:58 +01:00
Viktor Liu
71962f88f8 Add IPv6 reverse DNS and host configurator support 2026-03-25 09:57:26 +01:00
Viktor Liu
878dc45abf Fix govet non-constant format string in log.Warnf 2026-03-25 09:55:44 +01:00
Viktor Liu
1a7e835949 Fix CodeRabbit findings: hasIPv6Changed restart loop, empty peerIPs panic, v6 validation 2026-03-25 09:55:44 +01:00
Viktor Liu
b852ce1a99 Add IPv6 overlay address support to client interface and engine 2026-03-25 09:55:44 +01:00
Viktor Liu
013770070a Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay 2026-03-25 09:54:47 +01:00
Viktor Liu
9aaa05e8ea Replace discontinued LocalStack image with MinIO in S3 test (#5680) 2026-03-25 15:51:29 +08:00
Bethuel Mmbaga
0af5a0441f [management] Fix DNS label uniqueness check on peer rename (#5679) 2026-03-24 20:25:29 +03:00
Viktor Liu
0fc63ea0ba [management] Allow multiple header auths with same header name (#5678) 2026-03-24 16:18:21 +01:00
Bethuel Mmbaga
0b329f7881 [management] Replace JumpCloud SDK with direct HTTP calls (#5591) 2026-03-24 13:21:42 +03:00
Viktor Liu
5b85edb753 [management] Omit proxy_protocol from API response when false (#5656)
The internal Target model uses a plain bool for ProxyProtocol,
which was always serialized to the API response as false even
when not configured. Only set the API field when true so it
gets omitted via omitempty when unset.
2026-03-23 17:53:17 +01:00
Maycon Santos
17cfa5fe1e [misc] Set signing env only if not fork and set license (#5659)
* Add condition to GPG key decoding to handle pull requests

* Add license field to deb and rpm package configurations

* Add condition to GPG key decoding for external pull requests
2026-03-23 17:16:23 +01:00
Viktor Liu
2313494e0e [client] Don't abort debug for command when up/down fails (#5657) 2026-03-23 14:04:03 +01:00
Viktor Liu
fd9d430334 [client] Simplify entrypoint by running netbird up unconditionally (#5652) 2026-03-23 09:39:32 +01:00
Viktor Liu
acdf8d981a Merge branch 'main' into proto-ipv6-overlay 2026-03-22 18:34:17 +01:00
Zoltan Papp
91f0d5cefd [client] Feature/client metrics (#5512)
* Add client metrics

* Add client metrics system with OpenTelemetry and VictoriaMetrics support

Implements a comprehensive client metrics system to track peer connection
stages and performance. The system supports multiple backend implementations
(OpenTelemetry, VictoriaMetrics, and no-op) and tracks detailed connection
stage durations from creation through WireGuard handshake.

Key changes:
- Add metrics package with pluggable backend implementations
- Implement OpenTelemetry metrics backend
- Implement VictoriaMetrics metrics backend
- Add no-op metrics implementation for disabled state
- Track connection stages: creation, semaphore, signaling, connection ready, and WireGuard handshake
- Move WireGuard watcher functionality to conn.go
- Refactor engine to integrate metrics tracking
- Add metrics export endpoint in debug server

* Add signaling metrics tracking for initial and reconnection attempts

* Reset connection stage timestamps during reconnections to exclude unnecessary metrics tracking

* Delete otel lib from client

* Update unit tests

* Invoke callback on handshake success in WireGuard watcher

* Add Netbird version tracking to client metrics

Integrate Netbird version into VictoriaMetrics backend and metrics labels. Update `ClientMetrics` constructor and metric name formatting to include version information.

* Add sync duration tracking to client metrics

Introduce `RecordSyncDuration` for measuring sync message processing time. Update all metrics implementations (VictoriaMetrics, no-op) to support the new method. Refactor `ClientMetrics` to use `AgentInfo` for static agent data.

* Remove no-op metrics implementation and simplify ClientMetrics constructor

Eliminate unused `noopMetrics` and refactor `ClientMetrics` to always use the VictoriaMetrics implementation. Update associated logic to reflect these changes.

* Add total duration tracking for connection attempts

Calculate total duration for both initial connections and reconnections, accounting for different timestamp scenarios. Update `Export` method to include Prometheus HELP comments.

* Add metrics push support to VictoriaMetrics integration

* [client] anchor connection metrics to first signal received

* Remove creation_to_semaphore connection stage metric

The semaphore queuing stage (Created → SemaphoreAcquired) is no longer
tracked. Connection metrics now start from SignalingReceived. Updated
docs and Grafana dashboard accordingly.

* [client] Add remote push config for metrics with version-based eligibility

Introduce remoteconfig.Manager that fetches a remote JSON config to control
metrics push interval and restrict pushing to a specific agent version
range. When NB_METRICS_INTERVAL is set, remote config is bypassed
entirely for local override.

* [client] Add WASM-compatible NewClientMetrics implementation

Replace NewClientMetrics in metrics.go with a WASM-specific stub in metrics_js.go, returning nil for compatibility with JS builds. Simplify method usage for WASM targets.

* Add missing file

* Update default case in DeploymentType.String to return "unknown" instead of "selfhosted"

* [client] Rework metrics to use timestamped samples instead of histograms

Replace cumulative Prometheus histograms with timestamped point-in-time
samples that are pushed once and cleared. This fixes metrics for sparse
events (connections/syncs that happen once at startup) where rate() and
increase() produced incorrect or empty results.

Changes:
- Switch from VictoriaMetrics histogram library to raw Prometheus text
  format with explicit millisecond timestamps
- Reset samples after successful push (no resending stale data)
- Rename connection_to_handshake → connection_to_wg_handshake
- Add netbird_peer_connection_count metric for ICE vs Relay tracking
- Simplify dashboard: point-based scatter plots, donut pie chart
- Add maxStalenessInterval=1m to VictoriaMetrics to prevent forward-fill
- Fix deployment_type Unknown returning "selfhosted" instead of "unknown"
- Fix inverted shouldPush condition in push.go

* [client] Add InfluxDB metrics backend alongside VictoriaMetrics

Add influxdb.go with timestamped line protocol export for sparse
one-shot events. Restore victoria.go to use proper Prometheus
histograms. Update Grafana dashboards, add InfluxDB datasource,
and update docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* [client] Fix metrics issues and update dev docker setup

- Fix StopPush not clearing push state, preventing restart
- Fix race condition reading currentConnPriority without lock in recordConnectionMetrics
- Fix stale comment referencing old metrics server URL
- Update docker-compose for InfluxDB: add scoped tokens, .env config, init scripts
- Rename docker-compose.victoria.yml to docker-compose.yml

* [client] Add anonymised peer tracking to pushed metrics

Introduce peer_id and connection_pair_id tags to InfluxDB metrics.
Public keys are hashed (truncated SHA-256) for anonymisation. The
connection pair ID is deterministic regardless of which side computes
it, enabling deduplication of reconnections in the ICE vs Relay
dashboard. Also pin Grafana to v11.6.0 for file-based provisioning
and fix datasource UID references.

* Remove unused dependencies from go.mod and go.sum

* Refactor InfluxDB ingest pipeline: extract validation logic

- Move line validation logic to `validateLine` and `validateField` helper functions.
- Improve error handling with structured validation and clearer separation of concerns.
- Add stderr redirection for error messages in `create-tokens.sh`.

* Set non-root user in Dockerfile for Ingest service

* Fix Windows CI: command line too long

* Remove Victoria metrics

* Add hashed peer ID as Authorization header in metrics push

* Revert influxdb in docker compose

* Enable gzip compression and authorization validation for metrics push and ingest

* Reducate code of complexity

* Update debug documentation to include metrics.txt description

* Increase `maxBodySize` limit to 50 MB and update gzip reader wrapping logic

* Refactor deployment type detection to use URL parsing for improved accuracy

* Update readme

* Throttle remote config retries on fetch failure

* Preserve first WG handshake timestamp, ignore rekeys

* Skip adding empty metrics.txt to debug bundle in debug mode

* Update default metrics server URL to https://ingest.netbird.io

* Atomic metrics export-and-reset to prevent sample loss between Export and Reset calls

* Fix doc

* Refactor Push configuration to improve clarity and enforce minimum push interval

* Remove `minPushInterval` and update push interval validation logic

* Revert ExportAndReset, it is acceptable data loss

* Fix metrics review issues: rename env var, remove stale infra, add tests

- Rename NB_METRICS_ENABLED to NB_METRICS_PUSH_ENABLED to clarify that
  collection is always active (for debug bundles) and only push is opt-in
- Change default config URL from staging to production (ingest.netbird.io)
- Delete broken Prometheus dashboard (used non-existent metric names)
- Delete unused VictoriaMetrics datasource config
- Replace committed .env with .env.example containing placeholder values
- Wire Grafana admin credentials through env vars in docker-compose
- Make metricsStages a pointer to prevent reset-vs-write race on reconnect
- Fix typed-nil interface in debug bundle path (GetClientMetrics)
- Use deterministic field order in InfluxDB Export (sorted keys)
- Replace Authorization header with X-Peer-ID for metrics push
- Fix ingest server timeout to use time.Second instead of float
- Fix gzip double-close, stale comments, trim log levels
- Add tests for influxdb.go and MetricsStages

* Add login duration metric, ingest tag validation, and duration bounds

- Add netbird_login measurement recording login/auth duration to management
  server, with success/failure result tag
- Validate InfluxDB tags against per-measurement allowlists in ingest server
  to prevent arbitrary tag injection
- Cap all duration fields (*_seconds) at 300s instead of only total_seconds
- Add ingest server tests for tag/field validation, bounds, and auth

* Add arch tag to all metrics

* Fix Grafana dashboard: add arch to drop columns, add login panels

* Validate NB_METRICS_SERVER_URL is an absolute HTTP(S) URL

* Address review comments: fix README wording, update stale comments

* Clarify env var precedence does not bypass remote config eligibility

* Remove accidentally committed pprof files

---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2026-03-22 12:45:41 +01:00
Viktor Liu
82762280ee [client] Add health check flag to status command and expose daemon status in output (#5650) 2026-03-22 12:39:40 +01:00
Viktor Liu
e2f774824b Add PeerCapability enum and disableIPv6 flag to proto
PeerCapability is reported in PeerSystemMeta.capabilities on login/sync.
Management uses it instead of version gating to determine client features.
disableIPv6 in Flags lets users opt out of IPv6 overlay.
2026-03-21 18:20:34 +01:00
Viktor Liu
3963072c43 Rename source_prefixes to sourcePrefixes for consistent JSON naming 2026-03-21 18:20:34 +01:00
Viktor Liu
8550765f38 Validate prefix length bounds in DecodePrefix 2026-03-21 18:20:34 +01:00
Viktor Liu
67fb6be40a Use copy into fixed arrays to satisfy gosec bounds checking 2026-03-21 18:20:34 +01:00
Viktor Liu
cd7290a497 Rename peer_prefixes to source_prefixes in FirewallRule 2026-03-21 18:20:34 +01:00
Viktor Liu
63c19dbf2e Rename peer_ips to peer_prefixes and simplify EncodePrefix with AsSlice 2026-03-21 18:20:34 +01:00
Viktor Liu
01c4d5761d Fix gosec and staticcheck lint errors from proto deprecation 2026-03-21 18:20:34 +01:00
Viktor Liu
e916e0d7fa Add proto fields for IPv6 overlay and compact IP encoding 2026-03-21 18:20:34 +01:00
Viktor Liu
b550a2face [management, proxy] Add require_subdomain capability for proxy clusters (#5628) 2026-03-20 11:29:50 +01:00
Viktor Liu
ab77508950 [client] Add env var for management gRPC max receive message size (#5622) 2026-03-19 17:33:50 +01:00
Viktor Liu
b9462f5c6b [client] Make raw table initialization non-fatal in firewall managers (#5621) 2026-03-19 17:33:38 +01:00
Viktor Liu
5ffaa5cdd6 [client] Fix duplicate log lines in containers (#5609) 2026-03-19 15:53:05 +01:00
Pascal Fischer
a1858a9cb7 [management] recover proxies after cleanup if heartbeat is still running (#5617) 2026-03-18 11:48:38 +01:00
Viktor Liu
212b34f639 [management] Add GET /reverse-proxies/clusters endpoint (#5611) 2026-03-18 11:15:56 +08:00
Viktor Liu
af8eaa23e2 [client] Restart engine when peer IP address changes (#5614) 2026-03-17 17:00:24 +01:00
Viktor Liu
f0eed50678 [management] Accept domain target type for L4 reverse proxy services (#5612) 2026-03-17 16:29:03 +01:00
Wouter van Os
19d94c6158 [client] Allow setting DNSLabels on client embed (#5493) 2026-03-17 16:12:37 +01:00
Viktor Liu
628eb56073 [client] Update go-m1cpu to v0.2.0 to fix SIGSEGV on macOS Tahoe (#5613) 2026-03-17 16:10:38 +01:00
eason
a590c38d8b [client] Fix IPv6 address formatting in DNS address construction (#5603)
Replace fmt.Sprintf("%s:%d", ip, port) with net.JoinHostPort() to
properly handle IPv6 addresses that need bracket wrapping (e.g.,
[2606:4700:4700::1111]:53 instead of 2606:4700:4700::1111:53).

Without this fix, configuring IPv6 nameservers causes "too many colons
in address" errors because Go's net.Dial cannot parse the malformed
address string.

Fixes #5601
Related to #4074

Co-authored-by: easonysliu <easonysliu@tencent.com>
2026-03-17 06:27:47 +01:00
Wesley Gimenes
4e149c9222 [client] update gvisor to build with Go 1.26.x (#5447)
Building the client with Go 1.26.x fails with errors:

```
[...]
/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:22:2: WaitReasonSelect redeclared in this block
	/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:22:2: other declaration of WaitReasonSelect
/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:23:2: WaitReasonChanReceive redeclared in this block
	/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:23:2: other declaration of WaitReasonChanReceive
/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go126.go:24:2: WaitReasonSemacquire redeclared in this block
	/builder/dl/go-mod-cache/gvisor.dev/gvisor@v0.0.0-20251031020517-ecfcdd2f171c/pkg/sync/runtime_constants_go125.go:24:2: other declaration of WaitReasonSemacquire
[...]
```

Fixes: https://github.com/netbirdio/netbird/issues/5290 ("Does not build with Go 1.26rc3")

Signed-off-by: Wesley Gimenes <wehagy@proton.me>
2026-03-17 06:09:12 +01:00
tham-le
59f5b34280 [client] add MTU option to embed.Options (#5550)
Expose MTU configuration in the embed package so embedded clients
can set the WireGuard tunnel MTU without the config file workaround.
This is needed for protocols like QUIC that require larger datagrams
than the default MTU of 1280.

Validates MTU range via iface.ValidateMTU() at construction time to
prevent invalid values from being persisted to config.

Closes #5549
2026-03-17 06:03:10 +01:00
n0pashkov
dff06d0898 [misc] Add netbird-tui to community projects (#5568) 2026-03-17 05:33:13 +01:00
Pascal Fischer
80a8816b1d [misc] Add image build after merge to main (#5605) 2026-03-16 18:00:23 +01:00
Viktor Liu
387e374e4b [proxy, management] Add header auth, access restrictions, and session idle timeout (#5587) 2026-03-16 15:22:00 +01:00
Viktor Liu
3e6baea405 [management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530) 2026-03-13 18:36:44 +01:00
Zoltan Papp
fe9b844511 [client] refactor auto update workflow (#5448)
Auto-update logic moved out of the UI into a dedicated updatemanager.Manager service that runs in the connection layer. The
UI no longer polls or checks for updates independently.
The update manager supports three modes driven by the management server's auto-update policy:
No policy set by mgm: checks GitHub for the latest version and notifies the user (previous behavior, now centralized)
mgm enforces update: the "About" menu triggers installation directly instead of just downloading the file — user still initiates the action
mgm forces update: installation proceeds automatically without user interaction
updateManager lifecycle is now owned by daemon, giving the daemon server direct control via a new TriggerUpdate RPC
Introduces EngineServices struct to group external service dependencies passed to NewEngine, reducing its argument count from 11 to 4
2026-03-13 17:01:28 +01:00
Pascal Fischer
2e1aa497d2 [proxy] add log-level flag (#5594) 2026-03-13 15:28:25 +01:00
Viktor Liu
529c0314f8 [client] Fall back to getent/id for SSH user lookup in static builds (#5510) 2026-03-13 15:22:02 +01:00
Pascal Fischer
d86875aeac [management] Exclude proxy from peer approval (#5588) 2026-03-13 15:01:59 +01:00
Zoltan Papp
f80fe506d5 [client] Fix DNS probe thread safety and avoid blocking engine sync (#5576)
* Fix DNS probe thread safety and avoid blocking engine sync

Refactor ProbeAvailability to prevent blocking the engine's sync mutex
during slow DNS probes. The probe now derives its context from the
server's own context (s.ctx) instead of accepting one from the caller,
and uses a mutex to ensure only one probe runs at a time — new calls
cancel the previous probe before starting. Also fixes a data race in
Stop() when accessing probeCancel without the probe mutex.

* Ensure DNS probe thread safety by locking critical sections

Add proper locking to prevent data races when accessing shared resources during DNS probe execution and Stop(). Update handlers snapshot logic to avoid conflicts with concurrent writers.

* Rename context and remove redundant cancellation

* Cancel first and lock

* Add locking to ensure thread safety when reactivating upstream servers
2026-03-13 13:22:43 +01:00
Maycon Santos
967c6f3cd3 [misc] Add GPG signing key support for rpm packages (#5581)
* [misc] Add GPG signing key support for deb and rpm packages

* [misc] Improve GPG key management for deb and rpm signing

* [misc] Extract GPG key import logic into a reusable script

* [misc] Add key fingerprint extraction and targeted export for GPG keys

* [misc] Remove passphrase from GPG keys before exporting

* [misc] Simplify GPG key management by removing import script

* [misc] Bump GoReleaser version to v2.14.3 in release workflow

* [misc] Replace GPG passphrase variables with NFPM-prefixed alternatives in workflows and configs

* [misc] Update naming conventions for package IDs and passphrase variables in workflows and configs

* [misc] Standardize NFPM variable naming in release workflow

* [misc] Adjust NFPM variable names for consistency in release workflow

* [misc] Remove Debian signing GPG key usage in workflows and configs
2026-03-13 09:47:00 +01:00
Pascal Fischer
e50e124e70 [proxy] Fix domain switching update (#5585) 2026-03-12 17:12:26 +01:00
Pascal Fischer
c545689448 [proxy] Wildcard certificate support (#5583) 2026-03-12 16:00:28 +01:00
Vlad
8f389fef19 [management] fix some concurrency potential issues (#5584) 2026-03-12 15:57:36 +01:00
Pascal Fischer
d3d6a327e0 [proxy] read cert from disk if available instead of cert manager (#5574)
* **New Features**
  * Asynchronous certificate prefetch that races live issuance with periodic on-disk cache checks to surface certificates faster.
  * Centralized recording and notification when certificates become available.
  * New on-disk certificate reading and validation to allow immediate use of cached certs.

* **Bug Fixes & Performance**
  * Optimized retrieval by polling disk while fetching in background to reduce latency.
  * Added cancellation and timeout handling to fail stalled certificate operations reliably.
2026-03-11 19:18:37 +01:00
Vlad
b5489d4986 [management] set components network map by default and optimize memory usage (#5575)
* Network map now defaults to compacted mode at startup; environment parsing issues yield clearer warnings and disabling compacted mode is logged.

* **Bug Fixes**
  * DNS enablement and nameserver selection now correctly respect group membership, reducing incorrect DNS assignments.

* **Refactor**
  * Internal routing and firewall rule generation streamlined for more consistent rule IDs and safer peer handling.

* **Performance**
  * Minor memory and slice allocation improvements for peer/group processing.
2026-03-11 18:19:17 +01:00
Maycon Santos
7a23c57cf8 [self-hosted] Remove extra proxy domain from getting started (#5573) 2026-03-11 15:52:42 +01:00
Pascal Fischer
11f891220e [management] create a shallow copy of the account when buffering (#5572) 2026-03-11 13:01:13 +01:00
Pascal Fischer
5585adce18 [management] add activity events for domains (#5548)
* add activity events for domains

* fix test

* update activity codes

* update activity codes
2026-03-09 19:04:04 +01:00
Pascal Fischer
f884299823 [proxy] refactor metrics and add usage logs (#5533)
* **New Features**
  * Access logs now include bytes_upload and bytes_download (API and schemas updated, fields required).
  * Certificate issuance duration is now recorded as a metric.

* **Refactor**
  * Metrics switched from Prometheus client to OpenTelemetry-backed meters; health endpoint now exposes OpenMetrics via OTLP exporter.

* **Tests**
  * Metric tests updated to use OpenTelemetry Prometheus exporter and MeterProvider.
2026-03-09 18:45:45 +01:00
Maycon Santos
15aa6bae1b [client] Fix exit node menu not refreshing on Windows (#5553)
* [client] Fix exit node menu not refreshing on Windows

TrayOpenedCh is not implemented in the systray library on Windows,
so exit nodes were never refreshed after the initial connect. Combined
with the management sync not having populated routes yet when the
Connected status fires, this caused the exit node menu to remain empty
permanently after disconnect/reconnect cycles.
Add a background poller on Windows that refreshes exit nodes while
connected, with fast initial polling to catch routes from management
sync followed by a steady 10s interval. On macOS/Linux, TrayOpenedCh
continues to handle refreshes on each tray open.
Also fix a data race on connectClient assignment in the server's connect()
method and add nil checks in CleanState/DeleteState to prevent panics
when connectClient is nil.

* Remove unused exitNodeIDs

* Remove unused exitNodeState struct
2026-03-09 18:39:11 +01:00
Pascal Fischer
11eb725ac8 [management] only count login request duration for successful logins (#5545) 2026-03-09 14:56:46 +01:00
Pascal Fischer
30c02ab78c [management] use the cache for the pkce state (#5516) 2026-03-09 12:23:06 +01:00
Zoltan Papp
3acd86e346 [client] "reset connection" error on wake from sleep (#5522)
Capture engine reference before actCancel() in cleanupConnection().

After actCancel(), the connectWithRetryRuns goroutine sets engine to nil,
causing connectClient.Stop() to skip shutdown. This allows the goroutine
to set ErrResetConnection on the shared state after Down() clears it,
causing the next Up() to fail.
2026-03-09 10:25:51 +01:00
Pascal Fischer
5c20f13c48 [management] fix domain uniqueness (#5529) 2026-03-07 10:46:37 +01:00
Pascal Fischer
e6587b071d [management] use realip for proxy registration (#5525) 2026-03-06 16:11:44 +01:00
Maycon Santos
85451ab4cd [management] Add stable domain resolution for combined server (#5515)
The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
 Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
 standalone management server is unaffected.
2026-03-06 08:43:46 +01:00
Pascal Fischer
a7f3ba03eb [management] aggregate grpc metrics by accountID (#5486) 2026-03-05 22:10:45 +01:00
Maycon Santos
4f0a3a77ad [management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes**
  * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data.
  * Improved error handling when account data is unavailable with fallback to configured default domain.

* **Tests**
  * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
2026-03-05 14:30:31 +01:00
Maycon Santos
44655ca9b5 [misc] add PR title validation workflow (#5503) 2026-03-05 11:43:18 +01:00
Viktor Liu
e601278117 [management,proxy] Add per-target options to reverse proxy (#5501) 2026-03-05 10:03:26 +01:00
Maycon Santos
8e7b016be2 [management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart.

Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column:

- Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists
- Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass
- Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission 
- Add composite index on (source, source_peer) for efficient queries
- Batch-limit and column-select the reaper query to avoid DB/GC spikes
- Filter out malformed rows with empty source_peer
2026-03-04 18:15:13 +01:00
Maycon Santos
9e01ea7aae [misc] Add ISSUE_TEMPLATE configuration file (#5500)
Add issue template config file  with support and troubleshooting links
2026-03-04 14:30:54 +01:00
hbzhost
cfc7ec8bb9 [client] Fix SSH JWT auth failure with Azure Entra ID iat backdating (#5471)
Increase DefaultJWTMaxTokenAge from 5 to 10 minutes to accommodate
identity providers like Azure Entra ID that backdate the iat claim
by up to 5 minutes, causing tokens to be immediately rejected.

Fixes #5449

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-04 14:11:14 +01:00
Misha Bragin
b3bbc0e5c6 Fix embedded IdP metrics to count local and generic OIDC users (#5498) 2026-03-04 12:34:11 +02:00
Pascal Fischer
d7c8e37ff4 [management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-03-03 18:39:46 +01:00
Zoltan Papp
05b66e73bc [client] Fix deadlock in route peer status watcher (#5489)
Wrap peerStateUpdate send in a nested select to prevent goroutine
blocking when the consumer has exited, which could fill the
subscription buffer and deadlock the Status mutex.
2026-03-03 13:50:46 +01:00
Jeremie Deray
01ceedac89 [client] Fix profile config directory permissions (#5457)
* fix user profile dir perm

* fix fileExists

* revert return var change

* fix anti-pattern
2026-03-03 13:48:51 +01:00
Misha Bragin
403babd433 [self-hosted] specify sql file location of auth, activity and main store (#5487) 2026-03-03 12:53:16 +02:00
Maycon Santos
47133031e5 [client] fix: client/Dockerfile to reduce vulnerabilities (#5217)
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-03-03 08:44:08 +01:00
Pascal Fischer
82da606886 [management] Add explicit target delete on service removal (#5420) 2026-03-02 18:25:44 +01:00
Viktor Liu
bbe5ae2145 [client] Flush buffer immediately to support gprc (#5469) 2026-03-02 15:17:08 +01:00
Viktor Liu
0b21498b39 [client] Fix close of closed channel panic in ConnectClient retry loop (#5470) 2026-03-02 10:07:53 +01:00
Viktor Liu
0ca59535f1 [management] Add reverse proxy services REST client (#5454) 2026-02-28 13:04:58 +08:00
Misha Bragin
59c77d0658 [self-hosted] support embedded IDP postgres db (#5443)
* Add postgres config for embedded idp

Entire-Checkpoint: 9ace190c1067

* Rename idpStore to authStore

Entire-Checkpoint: 73a896c79614

* Fix review notes

Entire-Checkpoint: 6556783c0df3

* Don't accept pq port = 0

Entire-Checkpoint: 80d45e37782f

* Optimize configs

Entire-Checkpoint: 80d45e37782f

* Fix lint issues

Entire-Checkpoint: 3eec968003d1

* Fail fast on combined postgres config

Entire-Checkpoint: b17839d3d8c6

* Simplify management config method

Entire-Checkpoint: 0f083effa20e
2026-02-27 14:52:54 +01:00
shuuri-labs
333e045099 Lower socket auto-discovery log from Info to Debug (#5463)
The discovery message was printing on every CLI invocation, which is
noisy for users on distros using the systemd template.
2026-02-26 17:51:38 +01:00
Zoltan Papp
c2c4d9d336 [client] Fix Server mutex held across waitForUp in Up() (#5460)
Up() acquired s.mutex with a deferred unlock, then called waitForUp()
while still holding the lock. waitForUp() blocks for up to 50 seconds
waiting on clientRunningChan/clientGiveUpChan, starving all concurrent
gRPC calls that require the same mutex (Status, ListProfiles, etc.).

Replace the deferred unlock with explicit s.mutex.Unlock() on every
early-return path and immediately before waitForUp(), matching the
pattern already used by the clientRunning==true branch.
2026-02-26 16:47:02 +01:00
Bethuel Mmbaga
9a6a72e88e [management] Fix user update permission validation (#5441) 2026-02-24 22:47:41 +03:00
Bethuel Mmbaga
afe6d9fca4 [management] Prevent deletion of groups linked to flow groups (#5439) 2026-02-24 21:19:43 +03:00
shuuri-labs
ef82905526 [client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket.
- Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
2026-02-24 17:02:06 +01:00
Zoltan Papp
d18747e846 [client] Exclude Flow domain from caching to prevent TLS failures (#5433)
* Exclude Flow domain from caching to prevent TLS failures due to stale records.

* Fix test
2026-02-24 16:48:38 +01:00
Maycon Santos
f341d69314 [management] Add custom domain counts and service metrics to self-hosted metrics (#5414) 2026-02-24 15:21:14 +01:00
Maycon Santos
327142837c [management] Refactor expose feature: move business logic from gRPC to manager (#5435)
Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
2026-02-24 15:09:30 +01:00
Zoltan Papp
f8c0321aee [client] Simplify DNS logging by removing domain list from log output (#5396) 2026-02-24 10:35:45 +01:00
Zoltan Papp
89115ff76a [client] skip UAPI listener in netstack mode (#5397)
In netstack (proxy) mode, the process lacks permission to create
/var/run/wireguard, making the UAPI listener unnecessary and causing
a misleading error log. Introduce NewUSPConfigurerNoUAPI and use it
for the netstack device to avoid attempting to open the UAPI socket
entirely. Also consolidate UAPI error logging to a single call site.
2026-02-24 10:35:23 +01:00
Maycon Santos
63c83aa8d2 [client,management] Feature/client service expose (#5411)
CLI: new expose command to publish a local port with flags for PIN, password, user groups, custom domain, name prefix and protocol (HTTP default).
Management/API: create/renew/stop expose sessions (streamed status), automatic naming/domain, TTL renewals, background expiration, new management RPCs and client methods.
UI/API: account settings now include peer_expose_enabled and peer_expose_groups; new activity codes for peer expose events.
2026-02-24 10:02:16 +01:00
Zoltan Papp
37f025c966 Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state
transitions. The mutex now covers the full duration of each handler
including the status check, the Up/Down call, and the flag update.

Note: if Up or Down commands are triggered in parallel with sleep/wake
events, the overall ordering of up/down/sleep/wake operations is still
not guaranteed beyond what the mutex provides within the handler itself.
2026-02-24 10:00:33 +01:00
Zoltan Papp
4a54f0d670 [Client] Remove connection semaphore (#5419)
* [Client] Remove connection semaphore

Remove the semaphore and the initial random sleep time (300ms) from the connectivity logic to speed up the initial connection time.

Note: Implement limiter logic that can prioritize router peers and keep the fast connection option for the first few peers.

* Remove unused function
2026-02-23 20:58:53 +01:00
Zoltan Papp
98890a29e3 [client] fix busy-loop in network monitor routing socket on macOS/BSD (#5424)
* [client] fix busy-loop in network monitor routing socket on macOS/BSD

After system wakeup, the AF_ROUTE socket created by Go's unix.Socket()
is non-blocking, causing unix.Read to return EAGAIN immediately and spin
at 100% CPU filling the log with thousands of warnings per second.

Replace the tight read loop with a unix.Select call that blocks until
the fd is readable, checking ctx cancellation on each 1-second timeout.
Fatal errors (EBADF, EINVAL) now return an error instead of looping.

* [client] add fd range validation in waitReadable to prevent out-of-bound errors
2026-02-23 20:58:27 +01:00
Pascal Fischer
9d123ec059 [proxy] add pre-shared key support (#5377) 2026-02-23 16:31:29 +01:00
Pascal Fischer
5d171f181a [proxy] Send proxy updates on account delete (#5375) 2026-02-23 16:08:28 +01:00
Vlad
22f878b3b7 [management] network map components assembling (#5193) 2026-02-23 15:34:35 +01:00
Misha Bragin
44ef1a18dd [self-hosted] add Embedded IdP metrics (#5407) 2026-02-22 11:58:35 +02:00
Misha Bragin
2b98dc4e52 [self-hosted] Support activity store engine in the combined server (#5406) 2026-02-22 11:58:17 +02:00
Zoltan Papp
2a26cb4567 [client] stop upstream retry loop immediately on context cancellation (#5403)
stop upstream retry loop immediately on context cancellation
2026-02-20 14:44:14 +01:00
Pascal Fischer
5ca1b64328 [management] access log sorting (#5378) 2026-02-20 00:11:55 +01:00
Pascal Fischer
36752a8cbb [proxy] add access log cleanup (#5376) 2026-02-20 00:11:28 +01:00
Maycon Santos
f117fc7509 [client] Log lock acquisition time in receive message handling (#5393)
* Log lock acquisition time in receive message handling

* use offerAnswer.SessionID for session id
2026-02-19 19:18:47 +01:00
Zoltan Papp
fc6b93ae59 [ios] Ensure route settlement on iOS before handling DNS responses (#5360)
* Ensure route settlement on iOS before handling DNS responses to prevent bypassing the tunnel.

* add more logs

* rollback debug changes

* rollback  changes

* [client] Improve logging and add comments for iOS route settlement logic

- Switch iOS route settlement log level from Debug to Trace for finer control.
- Add clarifying comments for `waitForRouteSettlement` on non-iOS platforms.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-02-19 18:53:10 +01:00
Vlad
564fa4ab04 [management] fix possible race condition on user role change (#5395) 2026-02-19 18:34:28 +01:00
Maycon Santos
a6db88fbd2 [misc] Update timestamp format with milliseconds (#5387)
* Update timestamp format with milliseconds

* fix tests
2026-02-19 11:23:42 +01:00
Misha Bragin
4b5294e596 [self-hosted] remove unused config example (#5383) 2026-02-19 08:14:11 +01:00
shuuri-labs
a322dce42a [self-hosted] create migration script for pre v0.65.0 to post v0.65.0 (combined) (#5350) 2026-02-18 20:59:55 +01:00
Maycon Santos
d1ead2265b [client] Batch macOS DNS domains to avoid truncation (#5368)
* [client] Batch macOS DNS domains across multiple scutil keys to avoid truncation

scutil has undocumented limits: 99-element cap on d.add arrays and ~2048
  byte value buffer for SupplementalMatchDomains. Users with 60+ domains
  hit silent domain loss. This applies the same batching approach used on
  Windows (nrptMaxDomainsPerRule=50), splitting domains into indexed
  resolver keys (NetBird-Match-0, NetBird-Match-1, etc.) with 50-element
  and 1500-byte limits per key.

* check for all keys on getRemovableKeysWithDefaults

* use multi error
2026-02-18 19:14:09 +01:00
Maycon Santos
bbca74476e [management] docker login on management tests (#5323) 2026-02-18 16:11:17 +01:00
Zoltan Papp
318cf59d66 [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU) (#5374)
* [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU)

* adjust QUIC initial packet size to 1232 based on RFC 9000 §14
2026-02-18 10:58:14 +01:00
Pascal Fischer
e9b2a6e808 [managment] add flag to disable the old legacy grpc endpoint (#5372) 2026-02-17 19:53:14 +01:00
Zoltan Papp
2dbdb5c1a7 [client] Refactor WG endpoint setup with role-based proxy activation (#5277)
* Refactor WG endpoint setup with role-based proxy activation

For relay connections, the controller (initiator) now activates the
wgProxy before configuring the WG endpoint, while the non-controller
(responder) configures the endpoint first with a delayed update, then
activates the proxy after. This prevents the responder from sending
traffic through the proxy before WireGuard is ready to receive it,
avoiding handshake congestion when both sides try to initiate
simultaneously.

For ICE connections, pass hasRelayBackup as the setEndpointNow flag
so the responder sets the endpoint immediately when a relay fallback
exists (avoiding the delayed update path since relay is already
available as backup).

On ICE disconnect with relay fallback, remove the duplicate
wgProxyRelay.Work() calls — the relay proxy is already active from
initial setup, so re-activating it is unnecessary.

In EndpointUpdater, split ConfigureWGEndpoint into explicit
configureAsInitiator and configureAsResponder paths, and add the
setEndpointNow parameter to let the caller control whether the
responder applies the endpoint immediately or defers it. Add unused
SwitchWGEndpoint and RemoveEndpointAddress methods. Remove the
wgConfigWorkaround sleep from the relay setup path.

* Fix redundant wgProxyRelay.Work() call during relay fallback setup

* Simplify WireGuard endpoint configuration by removing unused parameters and redundant logic
2026-02-17 19:28:26 +01:00
Pascal Fischer
2cdab6d7b7 [proxy] remove unused oidc config flags (#5369) 2026-02-17 18:04:30 +01:00
Diego Noguês
e49c0e8862 [infrastructure] Proxy infra changes (#5365)
* chore: remove docker extra_hosts settings

* chore: remove unnecessary envc from proxy.env
2026-02-17 17:37:44 +01:00
Misha Bragin
e7c84d0ead Start Management if external IdP is down (#5367)
Set ContinueOnConnectorFailure: true in the embedded Dex config so that the Management server starts successfully even when an external IdP connector is unreachable at boot time.
2026-02-17 16:08:41 +01:00
Zoltan Papp
1c934cca64 Ignore false lint alert (#5370) 2026-02-17 16:07:35 +01:00
Vlad
4aff4a6424 [management] fix utc difference on last seen status for a peer (#5348) 2026-02-17 13:29:32 +01:00
Zoltan Papp
1bd7190954 [proxy] Support WebSocket (#5312)
* Fix WebSocket support by implementing Hijacker interface

Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces
(Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware.

Without this delegation:
 - WebSocket connections fail (can't hijack the connection)
 - Streaming breaks (can't flush buffers)
 - HTTP/2 push doesn't work

* Add HijackTracker to manage hijacked connections during graceful shutdown

* Refactor HijackTracker to use middleware for tracking hijacked connections

* Refactor server handler chain setup for improved readability and maintainability
2026-02-17 12:53:34 +01:00
Viktor Liu
0146e39714 Add listener side proxy protocol support and enable it in traefik (#5332)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-02-16 23:40:10 +01:00
Zoltan Papp
baed6e46ec Reset WireGuard endpoint on ICE session change during relay fallback (#5283)
When an ICE connection disconnects and falls back to relay, reset the
WireGuard endpoint and handshake watcher if the remote peer's ICE session
has changed. This ensures the controller re-establishes a fresh WireGuard
handshake rather than waiting on a stale endpoint from the previous session.
2026-02-16 20:59:29 +01:00
Maycon Santos
0d1ffba75f [misc] add additional cname example (#5341) 2026-02-16 13:30:58 +01:00
Diego Romar
1024d45698 [mobile] Export lazy connection environment variables for mobile clients (#5310)
* [client] Export lazy connection env vars

Both for Android and iOS

* [client] Separate comments
2026-02-16 09:04:45 -03:00
Zoltan Papp
e5d4947d60 [client] Optimize Windows DNS performance with domain batching and batch mode (#5264)
* Optimize Windows DNS performance with domain batching and batch mode

Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch domains per NRPT
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls

* Update tests

* Fix log line

* Fix NRPT rule index to ensure cleanup covers partially created rules

* Ensure NRPT entry count updates even on errors to improve cleanup reliability

* Switch DNS batch mode logging from Info to Debug level

* Fix batch mode to not suppress critical DNS config updates

Batch mode should only defer applyHostConfig() for RegisterHandler/
DeregisterHandler operations. Management updates and upstream nameserver
failures (deactivate/reactivate callbacks) need immediate DNS config
updates regardless of batch mode to ensure timely failover.

Without this fix, if a nameserver goes down during a route update,
the system DNS config won't be updated until EndBatch(), potentially
delaying failover by several seconds.

Or if you prefer a shorter version:

Fix batch mode to allow immediate DNS updates for critical paths

Batch mode now only affects RegisterHandler/DeregisterHandler.
Management updates and nameserver failures always trigger immediate
DNS config updates to ensure timely failover.

* Add DNS batch cancellation to rollback partial changes on errors

Introduces CancelBatch() method to the DNS server interface to handle error
scenarios during batch operations. When route updates fail partway through, the DNS
server can now discard accumulated changes instead of applying partial state. This
prevents leaving the DNS configuration in an inconsistent state when route manager
operations encounter errors.

The changes add error-aware batch handling to prevent partial DNS configuration
updates when route operations fail, which improves system reliability.
2026-02-15 22:10:26 +01:00
Maycon Santos
cb9b39b950 [misc] add extra proxy domain instructions (#5328)
improve proxy domain instructions
expose wireguard port
2026-02-15 12:51:46 +01:00
Bethuel Mmbaga
68c481fa44 [management] Move service reload outside transaction in account settings update (#5325)
Bug Fixes

Network and DNS updates now defer service and reverse-proxy reloads until after account updates complete, preventing inconsistent proxy state and race conditions.
Chores

Removed automatic peer/broadcast updates immediately following bulk service reloads.
Tests

Added a test ensuring network-range changes complete without deadlock.
2026-02-14 20:27:15 +01:00
Misha Bragin
01a9cd4651 [misc] Fix reverse proxy getting started messaging (#5317)
* Fix reverse proxy getting started messaging

* Fix reverse proxy getting started messaging
2026-02-14 16:34:04 +01:00
Pascal Fischer
f53155562f [management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
2026-02-13 19:37:43 +01:00
Zoltan Papp
edce11b34d [client] Refactor/relay conn container (#5271)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.

* Refactor relay client Conn/connContainer ownership and decouple Conn from Client

Conn previously held a direct *Client pointer and called client methods
(writeTo, closeConn, LocalAddr) directly, creating a tight bidirectional
coupling. The message channel was also created externally in OpenConn and
shared between Conn and connContainer with unclear ownership.

Now connContainer fully owns the lifecycle of both the channel and the
Conn it wraps:
- connContainer creates the channel (sized by connChannelSize const)
  and the Conn internally via newConnContainer
- connContainer feeds messages into the channel (writeMsg), closes and
  drains it on shutdown (close)
- Conn reads from the channel (Read) but never closes it

Conn is decoupled from *Client by replacing the *Client field with
three function closures (writeFn, closeFn, localAddrFn) that are wired
by newConnContainer at construction time. Write, Close, and LocalAddr
delegate to these closures. This removes the direct dependency while
keeping the identity-check logic: writeTo and closeConn now compare
connContainer pointers instead of Conn pointers to verify the caller
is the current active connection for that peer.
2026-02-13 15:48:08 +01:00
Zoltan Papp
841b2d26c6 Add early message buffer for relay client (#5282)
Add early message buffer to capture transport messages
arriving before OpenConn completes, ensuring correct
message ordering and no dropped messages.
2026-02-13 15:41:26 +01:00
Bethuel Mmbaga
d3eeb6d8ee [misc] Add cloud api spec to public open api with rest client (#5222) 2026-02-13 15:08:47 +03:00
Bethuel Mmbaga
7ebf37ef20 [management] Enforce access control on accessible peers (#5301) 2026-02-13 12:46:43 +03:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
Maycon Santos
69d4b5d821 [misc] Update sign pipeline version (#5296) 2026-02-12 11:31:49 +01:00
Viktor Liu
3dfa97dcbd [client] Fix stale entries in nftables with no handle (#5272) 2026-02-12 09:15:57 +01:00
Viktor Liu
1ddc9ce2bf [client] Fix nil pointer panic in device and engine code (#5287) 2026-02-12 09:15:42 +01:00
Maycon Santos
2de1949018 [client] Check if login is required on foreground mode (#5295) 2026-02-11 21:42:36 +01:00
Vlad
fc88399c23 [management] fixed ischild check (#5279) 2026-02-10 20:31:15 +03:00
Zoltan Papp
6981fdce7e [client] Fix race condition and ensure correct message ordering in Relay (#5265)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.
2026-02-09 11:34:24 +01:00
Viktor Liu
08403f64aa [client] Add env var to skip DNS probing (#5270) 2026-02-09 11:09:11 +01:00
Viktor Liu
391221a986 [client] Fix uspfilter duplicate firewall rules (#5269) 2026-02-09 10:14:02 +01:00
Zoltan Papp
7bc85107eb Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) 2026-02-06 19:50:48 +01:00
Zoltan Papp
3be16d19a0 [management] Feature/grpc debounce msgtype (#5239)
* Add gRPC update debouncing mechanism

Implements backpressure handling for peer network map updates to
efficiently handle rapid changes. First update is sent immediately,
subsequent rapid updates are coalesced, ensuring only the latest
update is sent after a 1-second quiet period.

* Enhance unit test to verify peer count synchronization with debouncing and timeout handling

* Debounce based on type

* Refactor test to validate timer restart after pending update dispatch

* Simplify timer reset for Go 1.23+ automatic channel draining

Remove manual channel drain in resetTimer() since Go 1.23+ automatically
drains the timer channel when Stop() returns false, making the
select-case pattern unnecessary.
2026-02-06 19:47:38 +01:00
Vlad
af8f730bda [management] check stream start time for connecting peer (#5267) 2026-02-06 18:00:43 +01:00
eyJhb
c3f176f348 [client] Fix wrong URL being logged for DefaultAdminURL (#5252)
- DefaultManagementURL was being logged instead of DefaultAdminURL
2026-02-06 11:23:36 +01:00
Viktor Liu
0119f3e9f4 [client] Fix netstack detection and add wireguard port option (#5251)
- Add WireguardPort option to embed.Options for custom port configuration
- Fix KernelInterface detection to account for netstack mode
- Skip SSH config updates when running in netstack mode
- Skip interface removal wait when running in netstack mode
- Use BindListener for netstack to avoid port conflicts on same host
2026-02-06 10:03:01 +01:00
Viktor Liu
1b96648d4d [client] Always log dns forwader responses (#5262) 2026-02-05 14:34:35 +01:00
Zoltan Papp
d2f9653cea Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261)
Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic
when agent field is nil. This can occur during Windows suspend/resume
when network interfaces are disrupted or the pion/ice library returns
nil without error.

Also capture agent pointer in local variable before goroutine execution
to prevent race conditions.

Fixes service crashes on laptop wake-up.
2026-02-05 12:06:28 +01:00
Zoltan Papp
194a986926 Cache the result of wgInterface.ToInterface() using sync.Once (#5256)
Avoid repeated conversions during route setup. The toInterface helper ensures
the conversion happens only once regardless of how many routes are added
or removed.
2026-02-04 22:22:37 +01:00
Viktor Liu
f7732557fa [client] Add missing bsd flags in debug bundle (#5254) 2026-02-04 18:07:27 +01:00
Vlad
d488f58311 [management] fix set disconnected status for connected peer (#5247) 2026-02-04 11:44:46 +01:00
Pascal Fischer
6fdc00ff41 [management] adding account id validation to accessible peers handler (#5246) 2026-02-03 17:30:02 +01:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
Vlad
8931293343 [management] run cancelPeerRoutinesWithoutLock in sync (#5234) 2026-02-01 15:44:27 +01:00
Vlad
7b830d8f72 disable sync lim (#5233) 2026-02-01 14:37:00 +01:00
Misha Bragin
3a0cf230a1 Disable local users for a smooth single-idp mode (#5226)
Add LocalAuthDisabled option to embedded IdP configuration

This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external
identity providers (Google, OIDC, etc.).

This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page.

Key changes:

Added LocalAuthDisabled field to EmbeddedIdPConfig
Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth
Validation prevents disabling local auth if no external connectors are configured
Existing local users are preserved when disabled and can login again when re-enabled
Operations are idempotent (disabling already disabled is a no-op)
2026-02-01 14:26:22 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
755 changed files with 109811 additions and 7283 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

14
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -46,6 +46,5 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -97,6 +97,16 @@ jobs:
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
- name: Build combined
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 go build .
- name: Build combined 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
test:
name: "Client / Unit"
needs: [build-cache]
@@ -144,7 +154,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -204,7 +214,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -261,6 +271,53 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
@@ -352,12 +409,19 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
@@ -440,15 +504,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test
run: |
@@ -529,15 +596,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test
run: |

View File

@@ -63,10 +63,15 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: Generate test script
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
fail-fast: false

51
.github/workflows/pr-title-check.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -9,8 +9,8 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -160,7 +160,7 @@ jobs:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request'
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@v3
with:
registry: ghcr.io
@@ -169,6 +169,14 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
@@ -176,6 +184,7 @@ jobs:
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -185,6 +194,55 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*amd64*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
else
echo "main sha-$(git rev-parse --short HEAD)"
fi
}
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
for tag in $(resolve_tags); do
dst="${img_name}:${tag}"
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
done
}
export -f tag_and_push resolve_tags
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
done
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
@@ -251,6 +309,14 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
@@ -275,6 +341,24 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:

View File

@@ -61,8 +61,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -120,6 +140,40 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-proxy
dir: proxy/cmd/proxy
env: [CGO_ENABLED=0]
binary: netbird-proxy
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -132,18 +186,22 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
id: netbird-deb
license: BSD-3-Clause
id: netbird_deb
bindir: /usr/bin
builds:
- netbird
formats:
- deb
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
@@ -151,16 +209,19 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
id: netbird-rpm
license: BSD-3-Clause
id: netbird_rpm
bindir: /usr/bin
builds:
- netbird
formats:
- rpm
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
@@ -520,6 +581,104 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -598,6 +757,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +846,43 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews:
- ids:
- default
@@ -695,7 +903,7 @@ brews:
uploads:
- name: debian
ids:
- netbird-deb
- netbird_deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -703,7 +911,7 @@ uploads:
- name: yum
ids:
- netbird-rpm
- netbird_rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -61,7 +61,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird-ui-deb
id: netbird_ui_deb
package_name: netbird-ui
builds:
- netbird-ui
@@ -80,7 +80,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird-ui-rpm
id: netbird_ui_rpm
package_name: netbird-ui
builds:
- netbird-ui
@@ -95,11 +95,14 @@ nfpms:
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
uploads:
- name: debian
ids:
- netbird-ui-deb
- netbird_ui_deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -107,7 +110,7 @@ uploads:
- name: yum
ids:
- netbird-ui-rpm
- netbird_ui_rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -1,4 +1,4 @@
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
BSD 3-Clause License

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features
@@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2
FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -17,8 +17,7 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,8 +23,7 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -203,10 +203,11 @@ func (c *Client) PeersList() *PeerInfoArray {
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers {
pi := PeerInfo{
p.IP,
p.FQDN,
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
IP: p.IP,
IPv6: p.IPv6,
FQDN: p.FQDN,
ConnStatus: int(p.ConnStatus),
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi
}
@@ -237,43 +238,84 @@ func (c *Client) Networks() *NetworkArray {
return nil
}
routesMap := routeManager.GetClientRoutesWithNetID()
v6Merged := route.V6ExitMergeSet(routesMap)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
networkArray := &NetworkArray{
items: make([]Network, 0),
}
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
for id, routes := range routesMap {
if len(routes) == 0 {
continue
}
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
if _, skip := v6Merged[id]; skip {
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
if network == nil {
continue
}
networkArray.Add(network)
networkArray.Add(*network)
}
return networkArray
}
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.findBestRoutePeer(routes)
if err != nil {
log.Errorf("could not get peer info for route %s: %v", id, err)
return nil
}
network := &Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: selected,
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
}
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
network.Network = "0.0.0.0/0, ::/0"
}
return network
}
// findBestRoutePeer returns the peer actively routing traffic for the given
// HA route group. Falls back to the first connected peer, then the first peer.
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
netStr := routes[0].Network.String()
fullStatus := c.recorder.GetFullStatus()
for _, p := range fullStatus.Peers {
if _, ok := p.GetRoutes()[netStr]; ok {
return p, nil
}
}
for _, r := range routes {
p, err := c.recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if p.ConnStatus == peer.StatusConnected {
return p, nil
}
}
return c.recorder.GetPeer(routes[0].Peer)
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()

View File

@@ -1,10 +1,19 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
var (
// EnvKeyNBForceRelay Exported for Android java client
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
)
// EnvList wraps a Go map for export to Java

View File

@@ -2,11 +2,21 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
IPv6 string
FQDN string
ConnStatus string // Todo replace to enum
ConnStatus int
Routes PeerRoutes
}

View File

@@ -307,6 +307,24 @@ func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// GetDisableIPv6 reads disable IPv6 setting from config file
func (p *Preferences) GetDisableIPv6() (bool, error) {
if p.configInput.DisableIPv6 != nil {
return *p.configInput.DisableIPv6, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableIPv6, err
}
// SetDisableIPv6 stores the given value and waits for commit
func (p *Preferences) SetDisableIPv6(disable bool) {
p.configInput.DisableIPv6 = &disable
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)

View File

@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
netID := route.NetID(id)
routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id)
routesMap := manager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("%s with ids: %v", operationName, routes)
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"regexp"
"slices"
"strconv"
"strings"
)
@@ -26,8 +27,9 @@ type Anonymizer struct {
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
// Handle CIDR notation (e.g. "2001:db8::/32")
if prefix, err := netip.ParsePrefix(ip); err == nil {
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
}
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip

View File

@@ -13,7 +13,7 @@ import (
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
@@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
expect: "Access attempted from 2001:db8:ffff::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
expect: "Access attempted from [2001:db8:ffff::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
},
}

View File

@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
time.Sleep(1 * time.Second)
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
}
cmd.Println("netbird up")
time.Sleep(3 * time.Second)
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
}
cmd.Println("netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())

287
client/cmd/expose.go Normal file
View File

@@ -0,0 +1,287 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
exposeExternalPort uint16
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: ` netbird expose --with-password safe-pass 8080
netbird expose --protocol tcp 5432
netbird expose --protocol tcp --with-external-port 5433 5432
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
}
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
func isClusterProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp", "tls":
return true
default:
return false
}
}
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
// where domain display doesn't apply. TLS uses SNI so it has a domain.
func isPortBasedProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp":
return true
default:
return false
}
}
// extractPort returns the port portion of a URL like "tcp://host:12345", or
// falls back to the given default formatted as a string.
func extractPort(serviceURL string, fallback uint16) string {
u := serviceURL
if idx := strings.Index(u, "://"); idx != -1 {
u = u[idx+3:]
}
if i := strings.LastIndex(u, ":"); i != -1 {
if p := u[i+1:]; p != "" {
return p
}
}
return strconv.FormatUint(uint64(fallback), 10)
}
// resolveExternalPort returns the effective external port, defaulting to the target port.
func resolveExternalPort(targetPort uint64) uint16 {
if exposeExternalPort != 0 {
return exposeExternalPort
}
return uint16(targetPort)
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
if isClusterProtocol(exposeProtocol) {
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
}
} else if cmd.Flags().Changed("with-external-port") {
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
switch strings.ToLower(exposeProtocol) {
case "http", "https", "tcp", "udp", "tls":
return true
default:
return false
}
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
req := &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
}
if isClusterProtocol(exposeProtocol) {
req.ListenPort = uint32(resolveExternalPort(port))
}
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case expose.ProtocolHTTPS:
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unhandled protocol type: %d", p)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
if !ok {
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
printExposeReady(cmd, ready.Ready, port)
return nil
}
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", r.ServiceName)
if r.ServiceUrl != "" {
cmd.Printf(" URL: %s\n", r.ServiceUrl)
}
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
cmd.Printf(" Domain: %s\n", r.Domain)
}
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Internal: %d\n", port)
if isClusterProtocol(exposeProtocol) {
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
}
if r.PortAutoAssigned && exposeExternalPort != 0 {
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
}
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
}
defer authClient.Close()
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("check login required: %v", err)
}
jwtToken := ""

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -80,6 +81,15 @@ var (
Short: "",
Long: "",
SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
}
)
@@ -144,6 +154,7 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil
}
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -41,7 +41,7 @@ func init() {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")

View File

@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
SetFlagsFromEnvVars(rootCmd)
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())

View File

@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed")
return nil
},
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning {
cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil {

View File

@@ -0,0 +1,201 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -0,0 +1,523 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -13,6 +15,22 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
)
var (

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
)
var (

View File

@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
func normalizeLocalHost(host string) string {
if host == "*" {
return "0.0.0.0"
return ""
}
return host
}

View File

@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080",
expectedLocal: ":8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
description: "Wildcard * should bind to all interfaces (dual-stack)",
},
{
name: "wildcard for port only",

View File

@@ -20,6 +20,7 @@ import (
var (
detailFlag bool
ipv4Flag bool
ipv6Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
@@ -28,6 +29,7 @@ var (
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
checkFlag string
)
var statusCmd = &cobra.Command{
@@ -44,11 +46,13 @@ func init() {
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -56,6 +60,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
if checkFlag != "" {
return runHealthCheck(cmd)
}
err := parseFilters()
if err != nil {
return err
@@ -68,15 +76,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, false)
resp, err := getStatus(ctx, true, false)
if err != nil {
return err
}
status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired)
if needsAuth && !jsonFlag && !yamlFlag {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -93,13 +103,31 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
if ipv6Flag {
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
if ipv6 != "" {
cmd.Print(parseInterfaceIP(ipv6))
}
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: nbstatus.ParseDaemonStatus(status),
StatusFilter: statusFilter,
PrefixNamesFilter: prefixNamesFilter,
PrefixNamesFilterMap: prefixNamesFilterMap,
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
})
var statusOutputString string
switch {
case detailFlag:
@@ -121,7 +149,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
@@ -131,7 +159,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -185,6 +213,83 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func runHealthCheck(cmd *cobra.Command) error {
check := strings.ToLower(checkFlag)
switch check {
case "live", "ready", "startup":
default:
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
}
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}
switch check {
case "live":
return nil
case "ready":
return checkReadiness(resp)
case "startup":
return checkStartup(resp)
default:
return nil
}
}
func checkReadiness(resp *proto.StatusResponse) error {
daemonStatus := internal.StatusType(resp.GetStatus())
switch daemonStatus {
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
return nil
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
default:
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
}
}
func checkStartup(resp *proto.StatusResponse) error {
fullStatus := resp.GetFullStatus()
if fullStatus == nil {
return fmt.Errorf("startup check: no full status available")
}
if !fullStatus.GetManagementState().GetConnected() {
return fmt.Errorf("startup check: management not connected")
}
if !fullStatus.GetSignalState().GetConnected() {
return fmt.Errorf("startup check: signal not connected")
}
var relayCount, relaysConnected int
for _, r := range fullStatus.GetRelays() {
uri := r.GetURI()
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
continue
}
relayCount++
if r.GetAvailable() {
relaysConnected++
}
}
if relayCount > 0 && relaysConnected == 0 {
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
}
return nil
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {

View File

@@ -8,6 +8,7 @@ const (
disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
disableIPv6Flag = "disable-ipv6"
)
var (
@@ -17,6 +18,7 @@ var (
disableFirewall bool
blockLANAccess bool
blockInbound bool
disableIPv6 bool
)
func init() {
@@ -39,4 +41,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
}

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r, false)
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
@@ -430,6 +430,10 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
@@ -547,6 +551,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
@@ -661,6 +669,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}

View File

@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/util"
)

View File

@@ -14,6 +14,7 @@ import (
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -31,6 +33,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -69,6 +79,20 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// DisableIPv6 disables IPv6 overlay addressing
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
}
// validateCredentials checks that exactly one credential type is provided
@@ -100,6 +124,12 @@ func New(opts Options) (*Client, error) {
return nil, err
}
if opts.MTU != nil {
if err := iface.ValidateMTU(*opts.MTU); err != nil {
return nil, fmt.Errorf("invalid MTU: %w", err)
}
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -128,15 +158,25 @@ func New(opts Options) (*Client, error) {
}
}
var err error
var parsedLabels domain.List
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
return nil, fmt.Errorf("invalid dns labels: %w", err)
}
t := true
var config *profilemanager.Config
var err error
input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -156,6 +196,7 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil
}
@@ -177,6 +218,7 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
@@ -186,10 +228,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -339,17 +378,38 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
@@ -357,7 +417,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
}
}
return recorder.GetFullStatus(), nil
return c.recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.

45
client/embed/expose.go Normal file
View File

@@ -0,0 +1,45 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -36,6 +36,7 @@ type aclManager struct {
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
@@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
@@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering(
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
@@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering(
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
@@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering(
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
@@ -376,8 +384,13 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -385,6 +398,15 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
@@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -23,9 +23,15 @@ type Manager struct {
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -57,9 +63,39 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil
}
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
func (m *Manager) hasIPv6() bool {
return m.ipv6Client != nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
@@ -79,12 +115,38 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router after acl init failure: %v", err)
}
return fmt.Errorf("acl manager init: %w", err)
}
if m.hasIPv6() {
if err := m.router6.init(stateManager); err != nil {
if err := m.aclMgr.Reset(); err != nil {
log.Warnf("rollback acl after v6 router init failure: %v", err)
}
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router after v6 router init failure: %v", err)
}
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclMgr6.init(stateManager); err != nil {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router after v6 acl init failure: %v", err)
}
if err := m.aclMgr.Reset(); err != nil {
log.Warnf("rollback acl after v6 acl init failure: %v", err)
}
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router after v6 acl init failure: %v", err)
}
return fmt.Errorf("v6 acl manager init: %w", err)
}
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
}
// persist early to ensure cleanup of chains
@@ -112,7 +174,13 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@@ -126,25 +194,48 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule)
}
@@ -160,18 +251,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair)
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
if err := m.router.RemoveNatRule(pair); err != nil {
return err
}
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
return fmt.Errorf("remove v6 NAT rule: %w", err)
}
}
return nil
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
// Reset firewall to the default state
@@ -185,6 +321,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -208,19 +353,16 @@ func (m *Manager) AllowNetbird() error {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionAccept,
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
var merr *multierror.Error
if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err))
}
return nil
if m.hasIPv6() {
if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// Flush doesn't need to be implemented for this manager
@@ -250,6 +392,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
}
@@ -258,6 +403,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule)
}
@@ -266,7 +414,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
@@ -274,6 +441,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
@@ -282,6 +452,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
@@ -318,6 +491,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !m.rawSupported {
return fmt.Errorf("raw table not available")
}
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
@@ -375,12 +552,16 @@ func (m *Manager) initNoTrackChain() error {
return fmt.Errorf("add prerouting jump rule: %w", err)
}
m.rawSupported = true
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
if !m.rawSupported {
return nil
}
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
@@ -401,6 +582,7 @@ func (m *Manager) cleanupNoTrackChain() error {
return fmt.Errorf("clear and delete chain: %w", err)
}
m.rawSupported = false
return nil
}

View File

@@ -52,8 +52,10 @@ const (
snatSuffix = "_snat"
fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
)
type ruleInfo struct {
@@ -84,6 +86,7 @@ type router struct {
wgIface iFaceMapper
legacyManagement bool
mtu uint16
v6 bool
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
@@ -95,6 +98,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
rules: make(map[string][]string),
wgIface: wgIface,
mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
@@ -184,6 +188,11 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID()
@@ -423,6 +432,12 @@ func (r *router) createContainers() error {
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
@@ -529,9 +544,12 @@ func (r *router) addPostroutingRules() error {
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
@@ -716,8 +734,13 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -845,7 +868,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
@@ -872,7 +895,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
@@ -889,11 +912,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
return []string{"-m", "set", matchSet, name, direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
@@ -904,19 +928,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
log.Debugf("updated set %s with prefixes %v", name, prefixes)
}
return nberrors.FormatErrorOrNil(merr)
@@ -932,7 +952,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
@@ -990,10 +1010,22 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
// to avoid collisions since ipsets are global in the kernel.
func (r *router) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -9,6 +9,7 @@ type Rule struct {
mangleSpecs []string
ip string
chain string
v6 bool
}
// GetRuleID returns the rule id

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -37,6 +39,12 @@ type ShutdownState struct {
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
// Clean up v6 state even if the current run has no IPv6.
// The previous run may have left ip6tables rules behind.
if !ipt.hasIPv6() {
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
log.Warnf("failed to create v6 components for cleanup: %v", err)
}
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}
if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}

View File

@@ -33,15 +33,12 @@ const (
const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Offset: m.af.protoOffset,
Len: uint32(1),
})
protoData, err := protoToInt(proto)
protoData, err := m.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
})
}
rawIP := ip.To4()
rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrOffset := uint32(12)
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset,
Len: 4,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
},
)
// add individual IP for match if no ipset defined
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ip.To4()
rawIP := ipToBytes(ip, m.af)
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
Name: name,
Table: table,
Dynamic: true,
KeyType: nftables.TypeIPAddr,
KeyType: m.af.setKeyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
return b
}
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
}

View File

@@ -0,0 +1,81 @@
package nftables
import (
"fmt"
"net"
"github.com/google/nftables"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// addrFamily holds protocol-specific constants for nftables expression building.
type addrFamily struct {
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
protoOffset uint32
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
srcAddrOffset uint32
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
dstAddrOffset uint32
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
addrLen uint32
// totalBits is the address size in bits (32 for v4, 128 for v6)
totalBits int
// setKeyType is the nftables set data type for addresses
setKeyType nftables.SetDatatype
// tableFamily is the nftables table family
tableFamily nftables.TableFamily
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
icmpProto uint8
}
var (
// afIPv4 defines IPv4 header layout and nftables types.
afIPv4 = addrFamily{
protoOffset: 9,
srcAddrOffset: 12,
dstAddrOffset: 16,
addrLen: net.IPv4len,
totalBits: 8 * net.IPv4len,
setKeyType: nftables.TypeIPAddr,
tableFamily: nftables.TableFamilyIPv4,
icmpProto: unix.IPPROTO_ICMP,
}
// afIPv6 defines IPv6 header layout and nftables types.
afIPv6 = addrFamily{
protoOffset: 6,
srcAddrOffset: 8,
dstAddrOffset: 24,
addrLen: net.IPv6len,
totalBits: 8 * net.IPv6len,
setKeyType: nftables.TypeIP6Addr,
tableFamily: nftables.TableFamilyIPv6,
icmpProto: unix.IPPROTO_ICMPV6,
}
)
// familyForAddr returns the address family for the given IP.
func familyForAddr(is4 bool) addrFamily {
if is4 {
return afIPv4
}
return afIPv6
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return af.icmpProto, nil
case firewall.ProtocolALL:
return 0, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -49,8 +51,13 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
}
@@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface,
}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
tableName := getTableName()
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface, mtu)
@@ -75,11 +83,66 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil
}
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
if err != nil {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
}
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
if err := m.initFirewall(); err != nil {
return err
}
m.persistState(stateManager)
return nil
}
func (m *Manager) initFirewall() error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
@@ -90,20 +153,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
m.rollbackInit()
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
if m.hasIPv6() {
if err := m.initIPv6(); err != nil {
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
m.rollbackInit()
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
}
}
if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
}
return nil
}
// persistState saves the current interface state for potential recreation on restart.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
func (m *Manager) persistState(stateManager *statemanager.Manager) {
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
@@ -115,14 +190,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
}
return nil
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
log.Warnf("cleanup tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
log.Warnf("flush: %v", err)
}
}
// AddPeerFiltering rule to the firewall
@@ -141,12 +226,14 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
rawIP := ip.To4()
if rawIP == nil {
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@@ -160,8 +247,11 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -172,14 +262,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash, so the rule exists in exactly one
// router. We check v4 first; if the key isn't there, try v6.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule)
}
@@ -195,17 +309,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes (DomainSet) need NAT in both tables since resolved IPs
// can be either v4 or v6.
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair)
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
if err := m.router.RemoveNatRule(pair); err != nil {
return err
}
if m.hasIPv6() && pair.Destination.IsSet() {
v6Pair := pair
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
return fmt.Errorf("remove v6 NAT rule: %w", err)
}
}
return nil
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
@@ -217,6 +377,11 @@ func (m *Manager) AllowNetbird() error {
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
@@ -226,7 +391,13 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
// Close closes the firewall manager
@@ -234,23 +405,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (m *Manager) cleanupNetbirdTables() error {
@@ -299,6 +478,12 @@ func (m *Manager) Flush() error {
return err
}
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
@@ -311,6 +496,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
}
@@ -319,6 +507,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule)
}
@@ -327,7 +518,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
@@ -335,6 +545,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
@@ -343,6 +556,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && localAddr.Is6() {
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
@@ -517,7 +733,11 @@ func (m *Manager) refreshNoTrackChains() error {
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
}
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
@@ -529,7 +749,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
}
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
err = m.rConn.Flush()
return table, err
}

View File

@@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "failed to add DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
})
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
if _, err := exec.LookPath(bin); err != nil {
t.Skipf("%s not available on this system: %v", bin, err)
}
}
// Seed ip6 tables in the nft backend. Docker may not create them.
seedIp6tables(t)
ifaceMockV6 := &iFaceMock{
NameFunc: func() string { return "wt-test" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil), "close manager")
stdout, stderr := runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "add v6 route filtering rule")
err = manager.AddNatRule(fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
Masquerade: true,
})
require.NoError(t, err, "add v6 NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "add v6 DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
})
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
stdout, stderr = runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
}
func seedIp6tables(t *testing.T) {
t.Helper()
for _, tc := range []struct{ table, chain string }{
{"filter", "FORWARD"},
{"nat", "POSTROUTING"},
{"mangle", "FORWARD"},
} {
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
}
}
func runIp6tablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("ip6tables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "ip6tables-save failed")
return stdout.String(), stderr.String()
}
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
require.NotContains(t, stdout, "Table `nat' is incompatible",
"ip6tables-save: nat table incompatible. Full output: %s", stdout)
require.NotContains(t, stdout, "Table `mangle' is incompatible",
"ip6tables-save: mangle table incompatible. Full output: %s", stdout)
require.NotContains(t, stdout, "Table `filter' is incompatible",
"ip6tables-save: filter table incompatible. Full output: %s", stdout)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")

View File

@@ -46,8 +46,10 @@ const (
dnatSuffix = "_dnat"
snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
@@ -72,6 +74,7 @@ type router struct {
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
@@ -84,6 +87,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
@@ -142,7 +146,7 @@ func (r *router) Reset() error {
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
Family: r.af.tableFamily,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
@@ -175,7 +179,7 @@ func (r *router) removeNatPreroutingRules() error {
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
@@ -407,7 +411,7 @@ func (r *router) AddRouteFiltering(
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
protoNum, err := r.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
@@ -467,7 +471,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return getIpSetExprs(ref, isSource)
return r.getIpSetExprs(ref, isSource)
}
func (r *router) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) hasDNATRule(id string) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -483,7 +504,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -511,10 +537,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
KeyType: r.af.setKeyType,
}
elements := convertPrefixesToSet(prefixes)
elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
@@ -547,23 +573,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
@@ -573,10 +593,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
return netip.AddrFrom4(uint32ToBytes(lastIP))
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
}
// Utility function to convert netip.Addr to uint32.
@@ -660,13 +690,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -809,9 +858,12 @@ func (r *router) addPostroutingRules() {
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
exprsOut := []expr.Any{
&expr.Meta{
@@ -928,18 +980,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
@@ -1006,17 +1070,22 @@ func (r *router) acceptFilterTableRules() error {
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return r.acceptFilterRulesIptables(ipt)
if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
}
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1185,13 +1254,17 @@ func (r *router) removeFilterTableRules() error {
return nil
}
ipt, err := iptables.New()
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return r.removeAcceptFilterRulesIptables(ipt)
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
}
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
@@ -1258,7 +1331,7 @@ func (r *router) removeExternalChainsRules() error {
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
@@ -1282,8 +1355,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false
}
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false
}
@@ -1329,65 +1402,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("list rules: %w", err)
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
newRules[string(rule.UserData)] = rule
}
}
}
return nil
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1400,7 +1497,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
@@ -1463,7 +1560,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
@@ -1559,7 +1656,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
Family: r.af.tableFamily,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
@@ -1594,8 +1691,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
@@ -1629,20 +1726,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
@@ -1659,7 +1770,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
elements := convertPrefixesToSet(prefixes)
elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
@@ -1681,7 +1792,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
return nil
}
protoNum, err := protoToInt(protocol)
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
@@ -1712,7 +1823,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
@@ -1725,7 +1840,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
@@ -1757,16 +1872,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
@@ -1785,45 +1909,44 @@ func (r *router) applyNetwork(
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
return r.applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = 12
offset = r.af.srcAddrOffset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
// unspecified address (/0) doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
mask := net.CIDRMask(ones, r.af.totalBits)
xor := make([]byte, r.af.addrLen)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
Len: r.af.addrLen,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Len: r.af.addrLen,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
Xor: xor,
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
@@ -1906,13 +2029,12 @@ func getCtNewExprs() []expr.Any {
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = 12
offset = r.af.srcAddrOffset
}
return []expr.Any{
@@ -1920,7 +2042,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
Len: r.af.addrLen,
},
&expr.Lookup{
SourceRegister: 1,

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -89,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
testRouter := &router{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -507,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTableIPv6()
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IPv6",
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
},
{
name: "Multiple IPv6 Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::/48"),
netip.MustParsePrefix("fe80::/10"),
},
},
{
name: "Overlapping IPv6",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("fd00::1/128"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
},
},
{
name: "Mixed prefix lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::1/128"),
netip.MustParsePrefix("fd00:abcd::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
require.NoError(t, err, "Failed to create IPv6 set")
require.NotNil(t, set)
assert.Equal(t, setName, set.Name)
assert.True(t, set.Interval)
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd && len(elem.Key) == 16 {
ip := netip.AddrFrom16([16]byte(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
r.conn.DelSet(set)
require.NoError(t, r.conn.Flush())
})
}
}
func createWorkTableIPv6() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
err = sConn.Flush()
return table, err
}
func deleteWorkTableIPv6() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
_ = sConn.Flush()
}
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
@@ -626,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
expectedProto, _ := afIPv4.protoNum(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
@@ -719,3 +851,189 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}
func TestCalculateLastIP(t *testing.T) {
tests := []struct {
prefix string
want string
}{
{"10.0.0.0/24", "10.0.0.255"},
{"10.0.0.0/32", "10.0.0.0"},
{"0.0.0.0/0", "255.255.255.255"},
{"192.168.1.0/28", "192.168.1.15"},
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
{"fd00::/128", "fd00::"},
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
}
for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
prefix := netip.MustParsePrefix(tt.prefix)
got := calculateLastIP(prefix)
assert.Equal(t, tt.want, got.String())
})
}
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),
}
elements := r.convertPrefixesToSet(prefixes)
// Each prefix produces 2 elements (start + end)
require.Len(t, elements, 4)
// fd00::/64 start
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
assert.False(t, elements[0].IntervalEnd)
// fd00::/64 end (fd00:0:0:1::, one past the last)
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
assert.True(t, elements[1].IntervalEnd)
// 2001:db8::1/128 start
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
assert.False(t, elements[2].IntervalEnd)
// 2001:db8::1/128 end (2001:db8::2)
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
assert.True(t, elements[3].IntervalEnd)
}

View File

@@ -3,12 +3,6 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,15 +1,14 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -26,47 +25,26 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if !isWindowsFirewallReachable() {
return nil
}
if !isFirewallRuleActive(firewallRuleName) {
return nil
var merr *multierror.Error
if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
}
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
if isFirewallRuleActive(firewallRuleName + "-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic
@@ -75,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
return nil
}
if isFirewallRuleActive(firewallRuleName) {
return nil
if !isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
}
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6",
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+v6.String(),
); err != nil {
return err
}
}
return nil
}
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {

View File

@@ -1,8 +1,9 @@
package conntrack
import (
"fmt"
"net"
"net/netip"
"strconv"
"sync/atomic"
"time"
@@ -64,5 +65,7 @@ type ConnKey struct {
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
" → " +
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
}

View File

@@ -21,9 +21,10 @@ const (
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
MaxICMPPayloadLength = 48
)
// ICMPConnKey uniquely identifies an ICMP connection
@@ -74,32 +75,64 @@ func (info ICMPInfo) String() string {
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
// isErrorMessage returns true if this ICMP type carries original packet info.
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
// kept as a literal.
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
// ICMPv4 error types
if typ == layers.ICMPv4TypeDestinationUnreachable ||
typ == layers.ICMPv4TypeRedirect ||
typ == layers.ICMPv4TypeTimeExceeded ||
typ == layers.ICMPv4TypeParameterProblem {
return true
}
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
if typ == layers.ICMPv6TypeDestinationUnreachable ||
typ == layers.ICMPv6TypePacketTooBig ||
typ == layers.ICMPv6TypeParameterProblem {
return true
}
return false
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
if info.PayloadLen == 0 {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
version := (info.PayloadData[0] >> 4) & 0xF
var protocol uint8
var srcIP, dstIP net.IP
var transportData []byte
switch version {
case 4:
// 20-byte IPv4 header + 8-byte transport minimum
if info.PayloadLen < 28 {
return ""
}
protocol = info.PayloadData[9]
srcIP = net.IP(info.PayloadData[12:16])
dstIP = net.IP(info.PayloadData[16:20])
transportData = info.PayloadData[20:]
case 6:
// 40-byte IPv6 header + 8-byte transport minimum
if info.PayloadLen < 48 {
return ""
}
// Next Header field in IPv6 header
protocol = info.PayloadData[6]
srcIP = net.IP(info.PayloadData[8:24])
dstIP = net.IP(info.PayloadData[24:40])
transportData = info.PayloadData[40:]
default:
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
@@ -247,9 +280,10 @@ func (t *ICMPTracker) track(
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
return false
}
@@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() {
}
}
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
if ip.Is6() {
return nftypes.ICMPv6
}
return nftypes.ICMP
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
@@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
Protocol: icmpProtocolForAddr(conn.SourceIP),
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
@@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
Protocol: icmpProtocolForAddr(srcIP),
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,

View File

@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
if exists && !conn.IsSupersededBy(flags) {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsTombstone() {
if !exists || conn.IsSupersededBy(flags) {
return false
}

View File

@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -12,11 +13,13 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -24,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,8 +35,10 @@ import (
const (
layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
ipv4TCPHeaderMinSize = 40
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
ipv6TCPHeaderMinSize = 60
)
// serviceKey represents a protocol/port combination for netstack service registry
@@ -89,6 +95,7 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -132,9 +139,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
mtu uint16
mssClampValueIPv4 uint16
mssClampValueIPv6 uint16
mssClampEnabled bool
}
// decoder for packages
@@ -147,11 +155,28 @@ type decoder struct {
icmp4 layers.ICMPv4
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
parser4 *gopacket.DecodingLayerParser
parser6 *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// decodePacket decodes packet data using the appropriate parser based on IP version.
func (d *decoder) decodePacket(data []byte) error {
if len(data) == 0 {
return errors.New("empty packet")
}
version := data[0] >> 4
switch version {
case 4:
return d.parser4.DecodeLayers(data, &d.decoded)
case 6:
return d.parser6.DecodeLayers(data, &d.decoded)
default:
return fmt.Errorf("unknown IP version %d", version)
}
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
@@ -209,11 +234,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d
},
},
@@ -229,6 +260,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -238,7 +270,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
@@ -265,9 +298,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
if v6 := iface.Address().IPv6Net; v6.IsValid() {
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
}
rule, err := m.addRouteFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
sources,
firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL,
nil,
@@ -275,7 +313,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
firewall.ActionDrop,
)
if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err)
return nil, fmt.Errorf("block wg v4 net: %w", err)
}
if v6Net := iface.Address().IPv6Net; v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
if _, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
); err != nil {
return nil, fmt.Errorf("block wg v6 net: %w", err)
}
}
// TODO: Block networks that we're a client of
@@ -480,15 +533,19 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleID := uuid.New().String()
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
srcPort: sPort,
dstPort: dPort,
action: action,
@@ -499,6 +556,7 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -515,15 +573,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleID := rule.ID()
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -570,6 +633,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
@@ -600,11 +697,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
destinations = append(destinations, prefixes...)
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
@@ -643,7 +736,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
return false
}
@@ -724,12 +817,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
var mssClampValue uint16
var ipHeaderSize int
switch d.decoded[0] {
case layers.LayerTypeIPv4:
mssClampValue = m.mssClampValueIPv4
ipHeaderSize = int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
case layers.LayerTypeIPv6:
mssClampValue = m.mssClampValueIPv6
ipHeaderSize = 40
default:
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
if currentMSS > mssClampValue {
mssOptionIndex = i
break
}
@@ -740,20 +849,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
@@ -768,7 +872,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
@@ -778,18 +882,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
// Zero out existing checksum
tcpLayer[16] = 0
tcpLayer[17] = 0
// Build pseudo-header checksum based on IP version
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
switch d.decoded[0] {
case layers.LayerTypeIPv4:
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
case layers.LayerTypeIPv6:
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
}
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
}
pseudoSum += uint32(tcpLength)
pseudoSum += uint32(layers.IPProtocolTCP)
}
var sum = pseudoSum
sum := pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -827,6 +945,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
}
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
}
}
@@ -840,6 +961,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
}
d.dnatOrigPort = 0
@@ -899,15 +1023,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
return false
}
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
@@ -916,7 +1044,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true
}
@@ -1048,6 +1176,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
// both families uniformly. The echo ID is in the first two payload bytes.
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
if len(d.icmp6.Payload) >= 2 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
}
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
switch d.icmp6.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
case layers.ICMPv6TypeEchoReply:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
default:
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
}
return id, tc
}
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
// ICMP rules since management sends a single ICMP rule for both families.
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
if ruleLayer == packetLayer {
return true
}
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
return true
}
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
return true
}
return false
}
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
if p.Addr().Is6() {
return layers.LayerTypeIPv6
}
return layers.LayerTypeIPv4
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
@@ -1071,8 +1241,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
return nftypes.TCP
case layers.LayerTypeUDP:
return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
case layers.LayerTypeICMPv4:
return nftypes.ICMP
case layers.LayerTypeICMPv6:
return nftypes.ICMPv6
default:
return nftypes.ProtocolUnknown
}
@@ -1093,7 +1265,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false
}
@@ -1106,10 +1278,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
}
// Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
if l == 1 {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
return true, true
}
case layers.LayerTypeIPv6:
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
// only decoded the IPv6 layer, the transport is in a fragment.
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
return true, true
}
}
}
@@ -1147,21 +1327,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
size,
)
// TODO: ICMPv6
case layers.LayerTypeICMPv6:
id, _ := icmpv6EchoFields(d)
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
id,
d.icmp6.TypeCode.Type(),
size,
)
}
return false
}
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
func (m *Manager) isSpecialICMP(d *decoder) bool {
if d.decoded[1] != layers.LayerTypeICMPv4 {
return false
switch d.decoded[1] {
case layers.LayerTypeICMPv4:
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
case layers.LayerTypeICMPv6:
icmpType := d.icmp6.TypeCode.Type()
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
icmpType == layers.ICMPv6TypePacketTooBig ||
icmpType == layers.ICMPv6TypeTimeExceeded
}
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
return false
}
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
@@ -1218,7 +1411,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
if payloadLayer != rule.protoLayer {
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue
}
@@ -1259,8 +1452,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
return false
}
@@ -1465,7 +1657,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
}
// traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP {
addr := m.wgIface.Address()
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
return true
}

View File

@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
}
srcIP := net.ParseIP("100.64.0.2")
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")

View File

@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
}
}
func TestPeerACLFilteringIPv6(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
localIPv6 := netip.MustParseAddr("fd00::100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
wgNetV6 := netip.MustParsePrefix("fd00::/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
IPv6: localIPv6,
IPv6Net: wgNetV6,
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
err = manager.UpdateLocalIPs()
require.NoError(t, err)
testCases := []struct {
name string
srcIP string
dstIP string
proto fw.Protocol
srcPort uint16
dstPort uint16
ruleIP string
ruleProto fw.Protocol
ruleDstPort *fw.Port
ruleAction fw.Action
shouldBeBlocked bool
}{
{
name: "IPv6: allow TCP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow UDP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow ICMPv6 from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: block TCP without rule",
srcIP: "fd00::2",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "IPv6: drop rule",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "IPv6: allow all protocols",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 9999,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "0.0.0.0",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
}
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
})
}
}
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
t.Helper()
src := net.ParseIP(srcIP)
dst := net.ParseIP(dstIP)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
}
// Detect address family
isV6 := src.To4() == nil
var err error
switch proto {
case fw.ProtocolTCP:
ipLayer.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
err = tcp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
case fw.ProtocolUDP:
ipLayer.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
if isV6 {
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
SrcIP: src,
DstIP: dst,
}
err = udp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
case fw.ProtocolICMP:
ipLayer.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
switch proto {
case fw.ProtocolTCP:
ip6.NextHeader = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
case fw.ProtocolUDP:
ip6.NextHeader = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
case fw.ProtocolICMP:
ip6.NextHeader = layers.IPProtocolICMPv6
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
}
_ = icmp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip6)
}
} else {
ip4 := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: src,
DstIP: dst,
}
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ipLayer)
switch proto {
case fw.ProtocolTCP:
ip4.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
case fw.ProtocolUDP:
ip4.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
case fw.ProtocolICMP:
ip4.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip4)
}
}
require.NoError(t, err)
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
// but the ACL decision itself is correct.
func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
fw.ProtocolALL,
nil,
nil,
fw.ActionDrop,
)
require.NoError(t, err)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
proto gopacket.LayerType
srcPort uint16
dstPort uint16
allowed bool
}{
{
name: "IPv6 TCP to allowed dest",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: true,
},
{
name: "IPv6 TCP wrong port",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 443,
allowed: false,
},
{
name: "IPv6 UDP not matched by TCP rule",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeUDP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeICMPv6,
allowed: false,
},
{
name: "IPv6 to denied subnet",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 source outside allowed range",
srcIP: netip.MustParseAddr("fe80::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
})
}
}

View File

@@ -0,0 +1,376 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
@@ -430,11 +582,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d
},
}
@@ -535,11 +692,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d
},
}
@@ -945,8 +1107,8 @@ func TestMSSClamping(t *testing.T) {
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
@@ -964,7 +1126,7 @@ func TestMSSClamping(t *testing.T) {
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
@@ -988,7 +1150,7 @@ func TestMSSClamping(t *testing.T) {
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
@@ -1160,13 +1322,18 @@ func TestShouldForward(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
err = d.decodePacket(buf.Bytes())
require.NoError(t, err)
return d
@@ -1226,6 +1393,44 @@ func TestShouldForward(t *testing.T) {
},
}
// Add IPv6 to the interface and test dual-stack cases
wgIPv6 := netip.MustParseAddr("fd00::1")
otherIPv6 := netip.MustParseAddr("fd00::2")
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{
IP: wgIP,
Network: netip.PrefixFrom(wgIP, 24),
IPv6: wgIPv6,
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
}
}
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {
name string
dstIP netip.Addr
expected bool
description string
}{
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
}
for _, tt := range v6Cases {
t.Run(tt.name, func(t *testing.T) {
manager.localForwarding = true
manager.netstack = false
decoder := createTCPDecoder(8080)
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager

View File

@@ -1,7 +1,8 @@
package forwarder
import (
"fmt"
"net"
"strconv"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
raw := pkt.NetworkHeader().View().AsSlice()
if len(raw) == 0 {
continue
}
var address tcpip.Address
if raw[0]>>4 == 6 {
address = header.IPv6(raw).DestinationAddress()
} else {
address = header.IPv4(raw).DestinationAddress()
}
if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err)
continue
}
@@ -103,5 +110,7 @@ type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
" → " +
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
}

View File

@@ -14,6 +14,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -36,25 +37,31 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
ipv6 tcpip.Address
netstack bool
hasRawICMPAccess bool
hasRawICMPv6Access bool
pingSemaphore chan struct{}
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
HandleLocal: false,
})
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
PrefixLen: iface.Address().Network.Bits(),
},
}
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
if v6 := iface.Address().IPv6; v6.IsValid() {
v6Addr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16(v6.As16()),
PrefixLen: iface.Address().IPv6Net.Bits(),
},
}
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
}
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("creating default subnet: %w", err)
}
defaultSubnetV6, err := tcpip.NewSubnet(
tcpip.AddrFrom16([16]byte{}),
tcpip.MaskFromBytes(make([]byte, 16)),
)
if err != nil {
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
{Destination: defaultSubnet, NIC: nicID},
{Destination: defaultSubnetV6, NIC: nicID},
})
ctx, cancel := context.WithCancel(context.Background())
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
ipv6: addrFromNetipAddr(iface.Address().IPv6),
pingSemaphore: make(chan struct{}, 3),
}
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
// network layer. This avoids duplicate echo replies (v4) and the v6
// auto-reply bug where gVisor responds at the network layer before
// our transport handler fires.
f.checkICMPCapability()
@@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
if len(payload) == 0 {
return fmt.Errorf("empty packet")
}
var protoNum tcpip.NetworkProtocolNumber
switch payload[0] >> 4 {
case 4:
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv4.ProtocolNumber
case 6:
if len(payload) < header.IPv6MinimumSize {
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv6.ProtocolNumber
default:
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
}
return nil
}
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
// and auto-replies to all v6 echo requests in promiscuous mode.
//
// Unlike gVisor's network layer, this does not validate ICMP checksums or
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
ip := header.IPv4(payload)
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
return 0, src, dst, false
}
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
return 0, src, dst, false
}
ipHdrLen = int(ip.HeaderLength())
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize {
return 0, src, dst, false
}
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
ip := header.IPv6(payload)
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
return 0, src, dst, false
}
ipHdrLen = header.IPv6MinimumSize
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize {
return 0, src, dst, false
}
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
var (
ipHdrLen int
srcAddr tcpip.Address
dstAddr tcpip.Address
ok bool
)
switch payload[0] >> 4 {
case 4:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
case 6:
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
}
if !ok {
return false
}
// Let gVisor handle ICMP destined for our own addresses natively.
// Its network-layer auto-reply is correct and efficient for local traffic.
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
return false
}
id := stack.TransportEndpointID{
LocalAddress: dstAddr,
RemoteAddress: srcAddr,
}
// Build a PacketBuffer with headers consumed the same way gVisor would.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
return false
}
icmpPayload := payload[ipHdrLen:]
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
return false
}
if payload[0]>>4 == 6 {
return f.handleICMPv6(id, pkt)
}
return f.handleICMP(id, pkt)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -167,11 +303,14 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
}
return addr.AsSlice()
if f.netstack && f.ipv6.Equal(addr) {
return netip.IPv6Loopback()
}
return addrToNetipAddr(addr)
}
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
@@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
}
}
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
if !addr.IsValid() {
return tcpip.Address{}
}
if addr.Is4() {
return tcpip.AddrFrom4(addr.As4())
}
return tcpip.AddrFrom16(addr.As16())
}
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
switch addr.Len() {
case 4:
return netip.AddrFrom4(addr.As4())
case 16:
return netip.AddrFrom16(addr.As16())
default:
return netip.Addr{}
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
}
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
conn, err := lc.ListenPacket(ctx, network, addr)
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
return false
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
logger.Debug1("forwarder: raw %s socket access available", network)
return true
}

View File

@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true
@@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel()
network, listenAddr := "ip4:icmp", "0.0.0.0"
if v6 {
network, listenAddr = "ip6:ipv6-icmp", "::"
}
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
conn, err := lc.ListenPacket(ctx, network, listenAddr)
if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
dst := &net.IPAddr{IP: dstIP.AsSlice()}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
@@ -102,7 +107,7 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, false, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
@@ -150,19 +155,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
proto := nftypes.ICMP
if srcIp.Is6() {
proto = nftypes.ICMPv6
}
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
Protocol: proto,
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
RxBytes: rxBytes,
TxBytes: txBytes,
@@ -209,26 +218,159 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// handleICMPv6 handles ICMPv6 packets from the network stack.
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
if icmpHdr.Type() == header.ICMPv6EchoRequest {
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
if !f.hasRawICMPv6Access {
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
}
return true
}
// handleICMPv6Echo handles ICMPv6 echo requests using the ping binary.
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}()
default:
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
}
return true
}
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyHdr := header.ICMPv6(replyICMP)
replyHdr.SetType(header.ICMPv6EchoReply)
replyHdr.SetChecksum(0)
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: replyHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, replyICMP)
}
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv6MinimumSize)
ip := header.IPv6(ipHdr)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(icmpPayload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: 64,
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
return 0
}
return len(fullPacket)
}
const (
pingBin = "ping"
ping6Bin = "ping6"
)
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
isV6 := target.Is6()
timeoutStr := fmt.Sprintf("%d", timeoutSec)
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
}
}

View File

@@ -2,10 +2,9 @@ package forwarder
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"sync"
"github.com/google/uuid"
@@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
@@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
@@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,

View File

@@ -4,89 +4,32 @@ import (
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [256]*ipv4LowBitmap
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
// atomically so reads are lock-free.
type localIPSnapshot struct {
ips map[netip.Addr]struct{}
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
type localIPManager struct {
snapshot atomic.Pointer[localIPSnapshot]
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
m := &localIPManager{}
m.snapshot.Store(&localIPSnapshot{
ips: make(map[netip.Addr]struct{}),
})
return m
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}
addr, ok := netip.AddrFromSlice(ip)
parsed, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
parsed = parsed.Unmap()
ips[parsed] = struct{}{}
*addresses = append(*addresses, parsed)
}
}
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
}()
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr
ips := make(map[netip.Addr]struct{})
var addresses []netip.Addr
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
// loopback
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
ips[netip.IPv6Loopback()] = struct{}{}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
ip := iface.Address().IP
ips[ip] = struct{}{}
addresses = append(addresses, ip)
if v6 := iface.Address().IPv6; v6.IsValid() {
ips[v6] = struct{}{}
addresses = append(addresses, v6)
}
}
@@ -144,26 +88,27 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
processInterface(intf, ips, &addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
m.snapshot.Store(&localIPSnapshot{ips: ips})
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
log.Debugf("Local IP addresses: %v", addresses)
return nil
}
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
s := m.snapshot.Load()
if ip.Is4() && ip.As4()[0] == 127 {
return true
}
m.mu.RLock()
defer m.mu.RUnlock()
return m.checkBitmapBit(ip.AsSlice())
_, found := s.ips[ip]
return found
}

View File

@@ -0,0 +1,72 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func setupManager(b *testing.B) *localIPManager {
b.Helper()
m := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
if err := m.UpdateLocalIPs(mock); err != nil {
b.Fatalf("UpdateLocalIPs: %v", err)
}
return m
}
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("100.64.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("fd00::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("2001:db8::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_loopback(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("127.0.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}

View File

@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
expected: false,
},
{
name: "IPv6 address",
name: "IPv6 address matches",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"),
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::1"),
expected: true,
},
{
name: "IPv6 address does not match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::99"),
expected: false,
},
{
name: "No aliasing between similar IPs",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("fe80::1"),
testIP: netip.MustParseAddr("192.168.0.17"),
expected: false,
},
{
name: "IPv6 loopback",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
},
testIP: netip.MustParseAddr("::1"),
expected: true,
},
}
for _, tt := range tests {
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -5,6 +5,8 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -16,9 +18,18 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
defaultLogChanSize = 1000
)
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32
const (
@@ -69,7 +80,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
msgChannel: make(chan logMessage, getLogChannelSize()),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -13,8 +13,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
@@ -25,10 +23,33 @@ const (
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
ipv4SrcOffset = 12
ipv4DstOffset = 16
// IP address offsets in IPv6 header
ipv6SrcOffset = 8
ipv6DstOffset = 24
// IPv6 fixed header length
ipv6HeaderLen = 40
)
// ipHeaderLen returns the IP header length based on the decoded layer type.
func ipHeaderLen(d *decoder) (int, error) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
n := int(d.ip4.IHL) * 4
if n < 20 {
return 0, errInvalidIPHeaderLength
}
return n, nil
case layers.LayerTypeIPv6:
return ipv6HeaderLen, nil
default:
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
@@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
_, dstIP := extractPacketIPs(packetData, d)
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
@@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
srcIP, _ := extractPacketIPs(packetData, d)
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
@@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
case layers.LayerTypeIPv6:
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
}
return src, dst
}
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
case layers.LayerTypeIPv6:
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
default:
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is4() {
return ErrIPv4Only
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
}
offset := ipv4DstOffset
if isSource {
offset = ipv4SrcOffset
}
var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
copy(oldIP[:], packetData[offset:offset+4])
newIPBytes := newIP.As4()
copy(packetData[offset:offset+4], newIPBytes[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
// Recalculate IPv4 header checksum
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
// Update transport checksums incrementally
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
m.updateICMPChecksum(packetData, hdrLen)
}
}
return nil
}
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is6() {
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
}
offset := ipv6DstOffset
if isSource {
offset = ipv6SrcOffset
}
var oldIP [16]byte
copy(oldIP[:], packetData[offset:offset+16])
newIPBytes := newIP.As16()
copy(packetData[offset:offset+16], newIPBytes[:])
// IPv6 has no header checksum, only update transport checksums
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv6:
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
}
}
return nil
}
@@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+4 {
return
}
checksumOffset := icmpStart + 2
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -358,9 +449,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
@@ -515,12 +606,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
tcpStart := ipHeaderLen
tcpStart := hdrLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
@@ -546,12 +637,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
udpStart := ipHeaderLen
udpStart := hdrLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}

View File

@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.decodePacket(testPacket)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.decodePacket(packet)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {

View File

@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
err := d.decodePacket(packetData)
require.NoError(t, err)
return d
}

View File

@@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
}
func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
ipLayer, err := p.buildIPLayer()
if err != nil {
return nil, err
}
pktLayers := []gopacket.SerializableLayer{ipLayer}
transportLayer, err := p.buildTransportLayer(ip)
transportLayer, err := p.buildTransportLayer(ipLayer)
if err != nil {
return nil, err
}
@@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
if p.SrcIP.Is4() != p.DstIP.Is4() {
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
}
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
if p.SrcIP.Is6() {
return &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: proto,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}, nil
}
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
Protocol: proto,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}
}, nil
}
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ip)
return p.buildTCPLayer(ipLayer)
case "udp":
return p.buildUDPLayer(ip)
return p.buildUDPLayer(ipLayer)
case "icmp":
return p.buildICMPLayer()
return p.buildICMPLayer(ipLayer)
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
@@ -164,24 +180,37 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
if p.SrcIP.Is6() || p.DstIP.Is6() {
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
_ = icmp.SetNetworkLayerForChecksum(nl)
}
return []gopacket.SerializableLayer{icmp}, nil
}
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
@@ -204,14 +233,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol) int {
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
switch protocol {
case fw.ProtocolTCP:
return int(layers.IPProtocolTCP)
return layers.IPProtocolTCP
case fw.ProtocolUDP:
return int(layers.IPProtocolUDP)
return layers.IPProtocolUDP
case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4)
if isV6 {
return layers.IPProtocolICMPv6
}
return layers.IPProtocolICMPv4
default:
return 0
}
@@ -234,7 +266,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
@@ -256,6 +288,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
case layers.LayerTypeICMPv6:
trace.Protocol = "ICMPv6"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
@@ -415,7 +449,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
@@ -434,7 +468,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
@@ -444,7 +478,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
@@ -509,7 +543,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
srcIP, _ := extractPacketIPs(packetData, d)
translated := m.translateInboundReverse(packetData, d)
if translated {
@@ -539,7 +573,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
_, dstIP := extractPacketIPs(packetData, d)
translated := m.translateOutboundDNAT(packetData, d)
if translated {

View File

@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
opts := []grpc.DialOption{
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
}
opts = append(opts, extraOpts...)
conn, err := grpc.DialContext(connCtx, addr, opts...)
if err != nil {
return nil, fmt.Errorf("dial context: %w", err)
}

View File

@@ -5,20 +5,18 @@ package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err)
_ = uapiSock.Close()
return nil, err
}

View File

@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
@@ -111,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
@@ -558,7 +566,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
host, portStr, err := net.SplitHostPort(val)
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue

View File

@@ -2,7 +2,7 @@ package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error
ProtectSocket(fd int32) bool
}

View File

@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return nil, err

View File

@@ -131,23 +131,32 @@ func (t *TunDevice) Device() *device.Device {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
return fmt.Errorf("add v4 address: %s: %w", string(out), err)
}
// dummy ipv6 so routing works
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
// Assign a dummy link-local so macOS enables IPv6 on the tun device.
// When a real overlay v6 is present, use that instead.
v6Addr := "fe80::/64"
if t.address.HasIPv6() {
v6Addr = t.address.IPv6String()
}
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
t.address.ClearIPv6()
}
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
if out, err := routeCmd.CombinedOutput(); err != nil {
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
}
if t.address.HasIPv6() {
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
t.address.ClearIPv6()
}
}
return nil
}

View File

@@ -29,8 +29,9 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -151,8 +151,11 @@ func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
// UpdateAddr updates the device address. On iOS the tunnel is managed by the
// NetworkExtension, so we only store the new value. The extension picks up the
// change on the next tunnel reconfiguration.
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
t.address = addr
return nil
}

View File

@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
// assignAddr Adds IP address to the tunnel interface
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
return t.link.assignAddr(&t.address)
}
func (t *TunKernelDevice) GetNet() *netstack.Net {

View File

@@ -3,6 +3,7 @@ package device
import (
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
@@ -63,8 +64,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
addresses := []netip.Addr{t.address.IP}
if t.address.HasIPv6() {
addresses = append(addresses, t.address.IPv6)
}
log.Debugf("netstack using addresses: %v", addresses)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil {
@@ -79,10 +84,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
type TunDevice struct {
name string
address wgaddr.Address
port int
@@ -30,10 +30,10 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
return &TunDevice{
name: name,
address: address,
port: port,
@@ -43,7 +43,7 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
func (t *USPDevice) Create() (WGConfigurer, error) {
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil {
@@ -75,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -95,12 +95,12 @@ func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
func (t *USPDevice) Close() error {
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
@@ -115,39 +115,39 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *USPDevice) MTU() uint16 {
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *USPDevice) DeviceName() string {
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *USPDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *USPDevice) Device() *device.Device {
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error {
func (t *TunDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)
return link.assignAddr(&t.address)
}
func (t *USPDevice) GetNet() *netstack.Net {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -87,7 +87,19 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
err = nbiface.Set()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
}
if t.address.HasIPv6() {
nbiface6, err := luid.IPInterface(windows.AF_INET6)
if err != nil {
log.Warnf("failed to get IPv6 interface for MTU: %v", err)
} else {
nbiface6.NLMTU = uint32(t.mtu)
if err := nbiface6.Set(); err != nil {
log.Warnf("failed to set IPv6 interface MTU: %v", err)
}
}
}
err = t.assignAddr()
if err != nil {
@@ -178,8 +190,21 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
v4Prefix := t.address.Prefix()
if t.address.HasIPv6() {
v6Prefix := t.address.IPv6Prefix()
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
t.address.ClearIPv6()
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
}
return nil
}
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
}
func (t *TunDevice) GetNet() *netstack.Net {

View File

@@ -1,8 +0,0 @@
//go:build (!linux && !freebsd) || android
package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
}

View File

@@ -1,18 +0,0 @@
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@@ -0,0 +1,13 @@
//go:build !linux || android
package device
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
func WireGuardModuleIsLoaded() bool {
return false
}
// ModuleTunIsLoaded reports whether the tun device is available.
func ModuleTunIsLoaded() bool {
return true
}

View File

@@ -2,6 +2,7 @@ package device
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
@@ -57,32 +58,32 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)
}
ip := address.IP.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
err = link.AssignAddr(ip, mask)
if err != nil {
if err := link.AssignAddr(address.IP.String(), mask); err != nil {
return fmt.Errorf("assign addr: %w", err)
}
err = link.Up()
if err != nil {
if address.HasIPv6() {
log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
address.ClearIPv6()
}
}
if err := link.Up(); err != nil {
return fmt.Errorf("up: %w", err)
}

View File

@@ -4,6 +4,8 @@ package device
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
@@ -92,7 +94,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {
@@ -110,20 +112,16 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
}
name := l.attrs.Name
addrStr := address.String()
log.Debugf("adding address %s to interface: %s", addrStr, name)
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
if err := l.addAddr(name, address.Prefix()); err != nil {
return err
}
err = netlink.AddrAdd(l, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", name, addrStr)
} else if err != nil {
return fmt.Errorf("add addr: %w", err)
if address.HasIPv6() {
if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
address.ClearIPv6()
}
}
// On linux, the link must be brought up
@@ -133,3 +131,22 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
return nil
}
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: prefix.Addr().AsSlice(),
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
},
}
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
} else if err != nil {
return fmt.Errorf("add addr %s: %w", prefix, err)
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More