Compare commits

...

119 Commits

Author SHA1 Message Date
mlsmaycon
0163377b81 add logs for dns routes 2026-02-04 16:44:59 +01:00
Vlad
d488f58311 [management] fix set disconnected status for connected peer (#5247) 2026-02-04 11:44:46 +01:00
Pascal Fischer
6fdc00ff41 [management] adding account id validation to accessible peers handler (#5246) 2026-02-03 17:30:02 +01:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
Vlad
8931293343 [management] run cancelPeerRoutinesWithoutLock in sync (#5234) 2026-02-01 15:44:27 +01:00
Vlad
7b830d8f72 disable sync lim (#5233) 2026-02-01 14:37:00 +01:00
Misha Bragin
3a0cf230a1 Disable local users for a smooth single-idp mode (#5226)
Add LocalAuthDisabled option to embedded IdP configuration

This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external
identity providers (Google, OIDC, etc.).

This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page.

Key changes:

Added LocalAuthDisabled field to EmbeddedIdPConfig
Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth
Validation prevents disabling local auth if no external connectors are configured
Existing local users are preserved when disabled and can login again when re-enabled
Operations are idempotent (disabling already disabled is a no-op)
2026-02-01 14:26:22 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
Viktor Liu
06966da012 [client] Support non-PTY no-command interactive SSH sessions (#5093) 2026-01-27 11:05:04 +01:00
Viktor Liu
d4f7df271a [cllient] Don't track ebpf traffic in conntrack (#5166) 2026-01-27 11:04:23 +01:00
Maycon Santos
5299549eb6 [client] Update fyne and add exit menu retry (#5187)
* Update fyne and add exit menu retry

- Fix an extra arrow on macos by updating fyne/systray

* use systray.TrayOpenedCh instead of loop and retry
2026-01-27 09:52:55 +01:00
Misha Bragin
7d791620a6 Add user invite link feature for embedded IdP (#5157) 2026-01-27 09:42:20 +01:00
Zoltan Papp
44ab454a13 [management] Fix peer deletion error handling (#5188)
When a deleted peer tries to reconnect, GetUserIDByPeerKey was returning
Internal error instead of NotFound, causing clients to retry indefinitely
instead of recognizing the unrecoverable PermissionDenied error.

This fix:
1. Updates GetUserIDByPeerKey to properly return NotFound when peer doesn't exist
2. Updates Sync handler to convert NotFound to PermissionDenied with message
   'peer is not registered', matching the behavior of GetAccountIDForPeerKey

Fixes the regression introduced in v0.61.1 where deleted peers would see:
- Before: 'rpc error: code = Internal desc = failed handling request' (retry loop)
- After: 'rpc error: code = PermissionDenied desc = peer is not registered' (exits)
2026-01-26 23:15:34 +01:00
Misha Bragin
11f50d6c38 Include default groups claim in CLI audience (#5186) 2026-01-26 22:26:29 +01:00
Zoltan Papp
05af39a69b [client] Add IPv6 support to UDP WireGuard proxy (#5169)
* Add IPv6 support to UDP WireGuard proxy

Add IPv6 packet header support in UDP raw socket proxy
to handle both IPv4 and IPv6 source addresses.
Refactor error handling in proxy bind implementations
to validate endpoints before acquiring locks.
2026-01-26 14:03:32 +01:00
Viktor Liu
074df56c3d [client] Fix flaky JWT SSH test (#5181) 2026-01-26 09:30:00 +01:00
Maycon Santos
2381e216e4 Fix validator message with warn (#5168) 2026-01-24 17:49:25 +01:00
Zoltan Papp
ded04b7627 [client] Consolidate authentication logic (#5010)
* Consolidate authentication logic

- Moving auth functions from client/internal to client/internal/auth package
- Creating unified auth.Auth client with NewAuth() constructor
- Replacing direct auth function calls with auth client methods
- Refactoring device flow and PKCE flow implementations
- Updating iOS/Android/server code to use new auth client API

* Refactor PKCE auth and login methods

- Remove unnecessary internal package reference in PKCE flow test
- Adjust context assignment placement in iOS and Android login methods
2026-01-23 22:28:32 +01:00
Maycon Santos
67211010f7 [client, gui] fix exit nodes menu on reconnect, remove tooltips (#5167)
* [client, gui] fix exit nodes menu on reconnect

clean s.exitNodeStates when disconnecting

* disable tooltip for exit nodes and settings
2026-01-23 18:39:45 +01:00
Maycon Santos
c61568ceb4 [client] Change default rosenpass log level (#5137)
* Change default rosenpass log level

- Add support to environment configuration
- Change default log level to info

* use .String() for print log level
2026-01-23 18:06:54 +01:00
Vlad
737d6061bf [management] ephemeral peers track on login (#5165) 2026-01-23 18:05:22 +01:00
Zoltan Papp
ee3a67d2d8 [client] Fix/health result in bundle (#5164)
* Add support for optional status refresh callback during debug bundle generation

* Always update wg status

* Remove duplicated wg status call
2026-01-23 17:06:07 +01:00
Viktor Liu
1a32e4c223 [client] Fix IPv4-only in bind proxy (#5154) 2026-01-23 15:15:34 +01:00
Viktor Liu
269d5d1cba [client] Try next DNS upstream on SERVFAIL/REFUSED responses (#5163) 2026-01-23 11:59:52 +01:00
Bethuel Mmbaga
a1de2b8a98 [management] Move activity store encryption to shared crypt package (#5111) 2026-01-22 15:01:13 +03:00
Viktor Liu
d0221a3e72 [client] Add cpu profile to debug bundle (#4700) 2026-01-22 12:24:12 +01:00
Bethuel Mmbaga
8da23daae3 [management] Fix activity event initiator for user group changes (#5152) 2026-01-22 14:18:46 +03:00
Viktor Liu
f86022eace [client] Hide forwarding rules in status when count is zero (#5149)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 10:01:08 +01:00
Viktor Liu
ee54827f94 [client] Add IPv6 support to usersace bind (#5147) 2026-01-22 10:20:43 +08:00
Zoltan Papp
e908dea702 [client] Extend WG watcher for ICE connection too (#5133)
Extend WG watcher for ICE connection too
2026-01-21 10:42:13 +01:00
Maycon Santos
030650a905 [client] Fix RFC 4592 wildcard matching for existing domain names (#5145)
Per RFC 4592 section 2.2.1, wildcards should only match when the queried
name does not exist in the zone. Previously, if host.example.com had an
A record and *.example.com had an AAAA record, querying AAAA for
host.example.com would incorrectly return the wildcard AAAA instead of
NODATA.

Now the resolver checks if the domain exists (with any record type)
before falling back to wildcard matching, returning proper NODATA
responses for existing names without the requested record type.
2026-01-21 08:48:32 +01:00
Misha Bragin
e01998815e [infra] add embedded STUN to getting started (#5141) 2026-01-20 19:01:34 +01:00
Zoltan Papp
07e4a5a23c Fixes profile switching and repeated down/up command failures. (#5142)
When Down() and Up() are called in quick succession, the connectWithRetryRuns goroutine could set ErrResetConnection after Down() had cleared the state, causing the subsequent Up() to fail.

Fix by waiting for the goroutine to exit (via clientGiveUpChan) before Down() returns. Uses a 5-second timeout to prevent RPC timeouts while ensuring the goroutine completes in most cases.
2026-01-20 18:22:37 +01:00
Diego Romar
b3a2992a10 [client/android] - Fix Rosenpass connectivity for Android peers (#5044)
* [client] Add WGConfigurer interface

To allow Rosenpass to work both with kernel
WireGuard via wgctrl (default behavior) and
userspace WireGuard via IPC on Android/iOS
using WGUSPConfigurer

* [client] Remove Rosenpass debug logs

* [client] Return simpler peer configuration in outputKey method

ConfigureDevice, the method previously used in
outputKey via wgClient to update the device's
properties, is now defined in the WGConfigurer
interface and implemented both in kernel_unix and
usp configurers.

PresharedKey datatype was also changed from
boolean to [32]byte to compare it
to the original NetBird PSK, so that Rosenpass
may replace it with its own when necessary.

* [client] Remove unused field

* [client] Replace usage of WGConfigurer

Replaced with preshared key setter interface,
which only defines a method to set / update the preshared key.

Logic has been migrated from rosenpass/netbird_handler to client/iface.

* [client] Use same default peer keepalive value when setting preshared keys

* [client] Store PresharedKeySetter iface in rosenpass manager

To avoid no-op if SetInterface is called before generateConfig

* [client] Add mutex usage in rosenpass netbird handler

* [client] change implementation setting Rosenpass preshared key

Instead of providing a method to configure a device (device/interface.go),
it forwards the new parameters to the configurer (either
kernel_unix.go / usp.go).

This removes dependency on reading FullStats, and makes use of a common
method (buildPresharedKeyConfig in configurer/common.go) to build a
minimal WG config that only sets/updates the PSK.

netbird_handler.go now keeps s list of initializedPeers to choose whether
to set the value of "UpdateOnly" when calling iface.SetPresharedKey.

* [client] Address possible race condition

Between outputKey calls and peer removal; it
checks again if the peer still exists in the
peers map before inserting it in the
initializedPeers map.

* [client] Add psk Rosenpass-initialized check

On client/internal/peer/conn.go, the presharedKey
function would always return the current key
set in wgConfig.presharedKey.

This would eventually overwrite a key set
by Rosenpass if the feature is active.

The purpose here is to set a handler that will
check if a given peer has its psk initialized
by Rosenpass to skip updating the psk
via updatePeer (since it calls presharedKey
method in conn.go).

* Add missing updateOnly flag setup for usp peers

* Change common.go buildPresharedKeyConfig signature

PeerKey datatype changed from string to
wgTypes.Key. Callers are responsible for parsing
a peer key with string datatype.
2026-01-20 13:26:51 -03:00
Maycon Santos
202fa47f2b [client] Add support to wildcard custom records (#5125)
* **New Features**
  * Wildcard DNS fallback for eligible query types (excluding NS/SOA): attempts wildcard records when no exact match, rewrites wildcard names back to the original query, and rotates responses; preserves CNAME resolution.

* **Tests**
  * Vastly expanded coverage for wildcard behaviors, precedence, multi-record round‑robin, multi-type chains, multi-hop and cross-zone scenarios, and edge cases (NXDOMAIN/NODATA, fallthrough).

* **Chores**
  * CI lint config updated to ignore an additional codespell entry.
2026-01-20 17:21:25 +01:00
Misha Bragin
4888021ba6 Add missing activity events to the API response (#5140) 2026-01-20 15:12:22 +01:00
Misha Bragin
a0b0b664b6 Local user password change (embedded IdP) (#5132) 2026-01-20 14:16:42 +01:00
Diego Romar
50da5074e7 [client] change notifyDisconnected call (#5138)
On handleJobStream, when handling error codes 
from receiveJobRequest in the switch-case, 
notifying disconnected in cases where it isn't a 
disconnection breaks connection status reporting 
on mobile peers.

This commit changes it so it isn't called on
Canceled or Unimplemented status codes.
2026-01-20 07:14:33 -03:00
Zoltan Papp
58daa674ef [Management/Client] Trigger debug bundle runs from API/Dashboard (#4592) (#4832)
This PR adds the ability to trigger debug bundle generation remotely from the Management API/Dashboard.
2026-01-19 11:22:16 +01:00
Maycon Santos
245481f33b [client] fix: client/Dockerfile to reduce vulnerabilities (#5119)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-01-16 18:05:41 +01:00
shuuri-labs
b352ab84c0 Feat/quickstart reverse proxy assistant (#5100)
* add external reverse proxy config steps to quickstart script

* remove generated files

* - Remove 'press enter' prompt from post-traefik config since traefik requires no manual config
- Improve npm flow (ask users for docker network, user container names in config)

* fixes for npm flow

* nginx flow fixes

* caddy flow fixes

* Consolidate NPM_NETWORK, NGINX_NETWORK, CADDY_NETWORK into single
EXTERNAL_PROXY_NETWORK variable. Add read_proxy_docker_network()
function that prompts for Docker network for options 2-4 (Nginx,
NPM, Caddy). Generated configs now use container names when a
Docker network is specified.

* fix https for traefik

* fix sonar code smells

* fix sonar smell (add return to render_dashboard_env)

* added tls instructions to nginx flow

* removed unused bind_addr variable from quickstart.sh

* Refactor getting-started.sh for improved maintainability

Break down large functions into focused, single-responsibility components:
- Split init_environment() into 6 initialization functions
- Split print_post_setup_instructions() into 6 proxy-specific functions
- Add section headers for better code organization
- Fix 3 code smell issues (unused bind_addr variables)
- Add TLS certificate documentation for Nginx
- Link reverse proxy names to docs sections

Reduces largest function from 205 to ~90 lines while maintaining
single-file distribution. No functional changes.

* - Remove duplicate network display logic in Traefik instructions
- Use upstream_host instead of bind_addr for NPM forward hostname
- Use upstream_host instead of bind_addr in manual proxy route examples
- Prevents displaying invalid 0.0.0.0 as connection target in setup instructions

* add wait_management_direct to caddy flow to ensure script waits until containers are running/passing healthchecks before reporting 'done!'
2026-01-16 17:42:28 +01:00
ressys1978
3ce5d6a4f8 [management] Add idp timeout env variable (#4647)
Introduced the NETBIRD_IDP_TIMEOUT environment variable to the management service. This allows configuring a timeout for supported IDPs. If the variable is unset or contains an invalid value, a default timeout of 10 seconds is used as a fallback.

This is needed for larger IDP environments where 10s is just not enough time.
2026-01-16 16:23:37 +01:00
Misha Bragin
4c2eb2af73 [management] Skip email_verified if not present (#5118) 2026-01-16 16:01:39 +01:00
Misha Bragin
daf1449174 [client] Remove duplicate audiences check (#5117) 2026-01-16 14:25:02 +02:00
Misha Bragin
1ff7abe909 [management, client] Fix SSH server audience validator (#5105)
* **New Features**
  * SSH server JWT validation now accepts multiple audiences with backward-compatible handling of the previous single-audience setting and a guard ensuring at least one audience is configured.
* **Tests**
  * Test suites updated and new tests added to cover multiple-audience scenarios and compatibility with existing behavior.
* **Other**
  * Startup logging enhanced to report configured audiences for JWT auth.
2026-01-16 12:28:17 +01:00
Bethuel Mmbaga
067c77e49e [management] Add custom dns zones (#4849) 2026-01-16 12:12:05 +03:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
Zoltan Papp
d9118eb239 [client] Fix WASM peer connection to lazy peers (#5097)
WASM peers now properly initiate relay connections instead of waiting for offers that lazy peers won't send.
2026-01-13 13:33:15 +01:00
Nima Sadeghifard
94de656fae [misc] Add hiring announcement with link to careers.netbird.io (#5095) 2026-01-12 19:06:28 +01:00
Misha Bragin
37abab8b69 [management] Check config compatibility (#5087)
* Enforce HttpConfig overwrite when embeddedIdp is enabled

* Disable offline_access scope in dashboard by default

* Add group propagation foundation to embedded idp

* Require groups scope in dex config for okt and pocket

* remove offline_access from device default scopes
2026-01-12 17:09:03 +01:00
Viktor Liu
b12c084a50 [client] Fall through dns chain for custom dns zones (#5081) 2026-01-12 13:56:39 +01:00
Viktor Liu
394ad19507 [client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046) 2026-01-12 12:35:38 +01:00
Misha Bragin
614e7d5b90 Validate OIDC issuer when creating or updating (#5074) 2026-01-09 09:45:43 -05:00
Misha Bragin
f7967f9ae3 Feature/resolve local jwks keys (#5073) 2026-01-09 09:41:27 -05:00
Vlad
684fc0d2a2 [management] fix the issue with duplicated peers with the same key (#5053) 2026-01-09 11:49:26 +01:00
Viktor Liu
0ad0c81899 [client] Reorder userspace ACL checks to fail faster for better performance (#4226) 2026-01-09 09:13:04 +01:00
Viktor Liu
e8863fbb55 [client] Add non-root ICMP support to userspace firewall forwarder (#4792) 2026-01-09 02:53:37 +08:00
Zoltan Papp
9c9d8e17d7 Revert "Revert "[relay] Update GO version and QUIC version (#4736)" (#5055)" (#5071)
This reverts commit 24df442198.
2026-01-08 18:58:22 +01:00
Diego Noguês
fb71b0d04b [infrastructure] fix: disable Caddy debug (#5067) 2026-01-08 12:49:45 +01:00
Maycon Santos
ab7d6b2196 [misc] add new getting started to release (#5057) 2026-01-08 12:12:50 +01:00
Maycon Santos
9c5b2575e3 [misc] add embedded provider support metrics
count local vs idp users if embedded
2026-01-08 12:12:19 +01:00
Bethuel Mmbaga
00e2689ffb [management] Fix race condition in experimental network map when deleting account (#5064) 2026-01-08 14:10:09 +03:00
Misha Bragin
cf535f8c61 [management] Fix role change in transaction and update readme (#5060) 2026-01-08 12:07:59 +01:00
Maycon Santos
24df442198 Revert "[relay] Update GO version and QUIC version (#4736)" (#5055)
This reverts commit 8722b79799.
2026-01-07 19:02:20 +01:00
Zoltan Papp
8722b79799 [relay] Update GO version and QUIC version (#4736)
- Go 1.25.5
- QUIC 0.55.0
2026-01-07 16:30:29 +01:00
Vlad
afcdef6121 [management] add ssh authorized users to network map cache (#5048) 2026-01-07 15:53:18 +01:00
Zoltan Papp
12a7fa24d7 Add support for disabling eBPF WireGuard proxy via environment variable (#5047) 2026-01-07 15:34:52 +01:00
Zoltan Papp
6ff9aa0366 Refactor SSH server to manage listener lifecycle and expose active address via Addr method. (#5036) 2026-01-07 15:34:26 +01:00
Misha Bragin
e586c20e36 [management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)
Embed Dex as a built-in IdP to simplify self-hosting setup.
Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management.

more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
2026-01-07 14:52:32 +01:00
Pascal Fischer
5393ad948f [management] fix nil handling for extra settings (#5049) 2026-01-07 13:05:39 +01:00
Bethuel Mmbaga
20d6beff1b [management] Increment network serial on peer update (#5051)
Increment the serial on peer update and prevent double serial increments and account updates when updating a user while there are peers set to expire
2026-01-07 14:59:49 +03:00
Bethuel Mmbaga
d35b7d675c [management] Refactor integrated peer deletion (#5042) 2026-01-07 14:00:39 +03:00
Viktor Liu
f012fb8592 [client] Add port forwarding to ssh proxy (#5031)
* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
2026-01-07 12:18:04 +08:00
Vlad
7142d45ef3 [management] network map builder concurrent batch processing for peer updates (#5040) 2026-01-06 19:25:55 +01:00
Dennis Schridde
9bd578d4ea Fix ui-post-install.sh to use the full username (#4809)
Fixes #4808 by extracting the full username by:

- Get PID using pgrep
- Get UID from PID using /proc/${PID}/loginuid
- Get user name from UID using id
Also replaces "complex" pipe from ps to sed with a (hopefully) "simpler" (as in requiring less knowledge about the arguments of ps and regexps) invocation of cat and id.
2026-01-06 11:36:19 +01:00
Pascal Fischer
f022e34287 [shared] allow setting a user agent for the rest client (#5037) 2026-01-06 10:52:36 +01:00
Bethuel Mmbaga
7bb4fc3450 [management] Refactor integrated peer validator (#5035) 2026-01-05 20:55:22 +03:00
Maycon Santos
07856f516c [client] Fix/stuck connecting when can't access api.netbird.io (#5033)
- Connect on daemon start only if the file existed before
- fixed a bug that happened when the default profile config was removed, which would recreate it and reset the active profile to the default.
2026-01-05 13:53:17 +01:00
Zoltan Papp
08b782d6ba [client] Fix update download url (#5023) 2026-01-03 20:05:38 +03:00
Maycon Santos
80a312cc9c [client] add verbose flag for free ad tests (#5021)
add verbose flag for free ad tests
2026-01-03 11:32:41 +01:00
Zoltan Papp
9ba067391f [client] Fix semaphore slot leaks (#5018)
- Remove WaitGroup, make SemaphoreGroup a pure semaphore
- Make Add() return error instead of silently failing on context cancel
- Remove context parameter from Done() to prevent slot leaks
- Fix missing Done() call in conn.go error path
2026-01-03 09:10:02 +01:00
Pascal Fischer
7ac65bf1ad [management] Fix/delete groups without lock (#5012) 2025-12-31 11:53:20 +01:00
Zoltan Papp
2e9c316852 Fix UI stuck in "Connecting" state when daemon reports "Connected" status. (#5014)
The UI can get stuck showing "Connecting" status even after the daemon successfully connects and reports "Connected" status. This occurs because the condition to update the UI to "Connected" state checks the wrong flag.
2025-12-31 11:50:43 +01:00
shuuri-labs
96cdd56902 Feat/add support for forcing device auth flow on ios (#4944)
* updates to client file writing

* numerous

* minor

* - Align OnLoginSuccess behavior with Android (only call on nil error)
- Remove verbose debug logging from WaitToken in device_flow.go
- Improve TUN FD=0 fallback comments and warning messages
- Document why config save after login differs from Android

* Add nolint directive for staticcheck SA1029 in login.go

* Fix CodeRabbit review issues for iOS/tvOS SDK

- Remove goroutine from OnLoginSuccess callback, invoke synchronously
- Stop treating PermissionDenied as success, propagate as permanent error
- Replace context.TODO() with bounded timeout context (30s) in RequestAuthInfo
- Handle DirectUpdateOrCreateConfig errors in IsLoginRequired and LoginForMobile
- Add permission enforcement to DirectUpdateOrCreateConfig for existing configs
- Fix variable shadowing in device_ios.go where err was masked by := in else block

* Address additional CodeRabbit review issues for iOS/tvOS SDK

- Make tunFd == 0 a hard error with exported ErrInvalidTunnelFD (remove dead fallback code)
- Apply defaults in ConfigFromJSON to prevent partially-initialized configs
- Add nil guards for listener/urlOpener interfaces in public SDK entry points
- Reorder config save before OnLoginSuccess to prevent teardown race
- Add explanatory comment for urlOpener.Open goroutine

* Make urlOpener.Open() synchronous in device auth flow
2025-12-30 16:41:36 +00:00
Misha Bragin
9ed1437442 Add DEX IdP Support (#4949) 2025-12-30 07:42:34 -05:00
Pascal Fischer
a8604ef51c [management] filter own peer when having a group to peer policy to themself (#4956) 2025-12-30 10:49:43 +01:00
Nicolas Henneaux
d88e046d00 fix(router): nft tables limit number of peers source (#4852)
* fix(router): nft tables limit number of peers source batching them, failing at 3277 prefixes on nftables v1.0.9 with Ubuntu 24.04.3 LTS,  6.14.0-35-generic #35~24.04.1-Ubuntu

* fix(router): nft tables limit number of prefixes on ipSet creation
2025-12-30 10:48:17 +01:00
Pascal Fischer
1d2c7776fd [management] apply login filter only for setup key peers (#4943) 2025-12-30 10:46:00 +01:00
Haruki Hasegawa
4035f07248 [client] Fix Advanced Settings not opening on Windows with Japanese locale (#4455) (#4637)
The Fyne framework does not support TTC font files.
Use the default system font (Segoe UI) instead, so Windows can
automatically fall back to a Japanese font when needed.
2025-12-30 10:36:12 +01:00
Zoltan Papp
ef2721f4e1 Filter out own peer from remote peers list during peer updates. (#4986) 2025-12-30 10:29:45 +01:00
Louis Li
e11970e32e [client] add reset for management backoff (#4935)
Reset client management grpc client backoff after successful connected to management API.

Current Situation:
If the connection duration exceeds MaxElapsedTime, when the connection is interrupted, the backoff fails immediately due to timeout and does not actually perform a retry.
2025-12-30 08:37:49 +01:00
Maycon Santos
38f9d5ed58 [infra] Preset signal port on templates (#5004)
When passing certificates to signal, it will select port 443 when no port is supplied. This changes forces port 80.
2025-12-29 18:07:06 +03:00
Pascal Fischer
b6a327e0c9 [management] fix scanning authorized user on policy rule (#5002) 2025-12-29 15:03:16 +01:00
Zoltan Papp
67f7b2404e [client, management] Feature/ssh fine grained access (#4969)
Add fine-grained SSH access control with authorized users/groups
2025-12-29 12:50:41 +01:00
Zoltan Papp
73201c4f3e Add conditional checks for FreeBSD diff file generation in release workflow (#5001) 2025-12-29 12:47:38 +01:00
Carlos Hernandez
33d1761fe8 Apply DNS host config on change only (#4695)
Adds a per-instance uint64 hash to DefaultServer to detect identical merged host DNS configs (including extra domains). applyHostConfig computes and compares the hash, skips applying if unchanged, treats hash errors as a fail-safe (proceed to apply), and updates the stored hash only after successful hashing and apply.
2025-12-29 12:43:57 +01:00
August
aa914a0f26 [docs] Fix broken image link (#4876) 2025-12-24 22:06:35 +05:00
Maycon Santos
ab6a9e85de [misc] Use new sign pipelines 0.1.0 (#4993) 2025-12-24 22:03:14 +05:00
Maycon Santos
d3b123c76d [ci] Add FreeBSD port release job to GitHub Actions (#4916)
adds a job that produces new freebsd release files
2025-12-24 11:22:33 +01:00
Viktor Liu
fc4932a23f [client] Fix Linux UI flickering on state updates (#4886) 2025-12-24 11:06:13 +01:00
Zoltan Papp
b7e98acd1f [client] Android profile switch (#4884)
Expose the profile-manager service for Android. Logout was not part of the manager service implementation. In the future, I recommend moving this logic there.
2025-12-22 22:09:05 +01:00
Maycon Santos
433bc4ead9 [client] lookup for management domains using an additional timeout (#4983)
in some cases iOS and macOS may be locked when looking for management domains during network changes

This change introduce an additional timeout on top of the context call
2025-12-22 20:04:52 +01:00
Zoltan Papp
011cc81678 [client, management] auto-update (#4732) 2025-12-19 19:57:39 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
474 changed files with 50400 additions and 6191 deletions

View File

@@ -1,15 +1,15 @@
FROM golang:1.23-bullseye
FROM golang:1.25-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
gettext-base=0.21-12 \
iptables=1.8.9-2 \
libgl1-mesa-dev=22.3.6-1+deb12u1 \
xorg-dev=1:7.7+23 \
libayatana-appindicator3-dev=0.5.92-1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@v0.18.1
&& go install -v golang.org/x/tools/gopls@latest
WORKDIR /app

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...

View File

@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.24-alpine \
golang:1.25-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./relay/... ./shared/relay/...
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
golangci:
strategy:
@@ -52,7 +52,10 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
args: --timeout=12m --out-format colored-line-number
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=12m

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.23"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,6 +19,100 @@ concurrency:
cancel-in-progress: true
jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:

View File

@@ -243,6 +243,7 @@ jobs:
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose logs
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db

View File

@@ -14,6 +14,9 @@ jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
env:
GOOS: js
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -24,16 +27,14 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
skip-save-cache: true
cache-invalidation-interval: 0
working-directory: ./client
continue-on-error: true
js_build:

1
.gitignore vendored
View File

@@ -31,3 +31,4 @@ infrastructure_files/setup-*.env
.DS_Store
vendor/
/netbird
client/netbird-electron/

View File

@@ -1,139 +1,124 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
revive:
rules:
- name: exported
severity: warning
disabled: false
arguments:
- "checkPrivateReceivers"
- "sayRepetitiveInsteadOfStutters"
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
version: "2"
linters:
disable-all: true
default: none
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
- bodyclose
- dupword
- durationcheck
- errcheck
- forbidigo
- gocritic
- gosec
- govet
- ineffassign
- mirror
- misspell
- nilerr
- nilnil
- predeclared
- revive
- sqlclosecheck
- staticcheck
- unused
- wastedassign
settings:
errcheck:
check-type-assertions: false
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
gosec:
includes:
- G101
- G103
- G104
- G106
- G108
- G109
- G110
- G111
- G201
- G202
- G203
- G301
- G302
- G303
- G304
- G305
- G306
- G307
- G403
- G502
- G503
- G504
- G601
- G602
govet:
enable:
- nilness
enable-all: false
revive:
rules:
- name: exported
arguments:
- checkPrivateReceivers
- sayRepetitiveInsteadOfStutters
severity: warning
disabled: false
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- linters:
- forbidigo
path: management/cmd/root\.go
- linters:
- forbidigo
path: signal/cmd/root\.go
- linters:
- unused
path: sharedsock/filter\.go
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- gosec
- mirror
path: test\.go
- linters:
- nilnil
path: mock\.go
- linters:
- staticcheck
text: grpc.DialContext is deprecated
- linters:
- staticcheck
text: grpc.WithBlock is deprecated
- linters:
- staticcheck
text: "QF1001"
- linters:
- staticcheck
text: "QF1008"
- linters:
- staticcheck
text: "QF1012"
paths:
- third_party$
- builtin$
- examples$
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -713,8 +713,10 @@ checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh

View File

@@ -38,6 +38,11 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
@@ -55,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features
@@ -85,7 +90,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
@@ -98,7 +103,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
@@ -113,7 +118,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -59,7 +59,6 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -68,18 +67,16 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -122,16 +124,22 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -149,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources

View File

@@ -3,15 +3,7 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
go urlOpener.OnLoginSuccess()
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -0,0 +1,257 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -136,6 +134,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
//nolint
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
@@ -220,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
cpuProfilingStarted := false
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -301,25 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -378,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogFile: logFilePath,
LogPath: logFilePath,
CPUProfile: nil,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -7,7 +7,6 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -81,6 +80,7 @@ var loginCmd = &cobra.Command{
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
@@ -206,6 +206,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
@@ -275,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}
jwtToken := ""
@@ -298,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return nil
@@ -342,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -1,5 +1,4 @@
//go:build pprof
// +build pprof
package cmd

View File

@@ -85,6 +85,9 @@ var (
// Execute executes the root command.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute()
}
@@ -387,6 +390,7 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -0,0 +1,176 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -0,0 +1,276 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

21
client/cmd/signer/main.go Normal file
View File

@@ -0,0 +1,21 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,220 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,74 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
if err := validateDestinationPort(remoteAddr); err != nil {
return fmt.Errorf("invalid remote address: %w", err)
}
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
if err := validateDestinationPort(localAddr); err != nil {
return fmt.Errorf("invalid local address: %w", err)
}
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil
}
// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}
if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}
if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {

View File

@@ -99,17 +99,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = outputInformationHolder.YAML()
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
}
if err != nil {
@@ -124,6 +124,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -89,9 +90,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -100,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -118,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -127,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -197,10 +197,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
@@ -216,6 +216,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

13
client/cmd/update.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -0,0 +1,75 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -16,10 +16,12 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
@@ -38,6 +40,7 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}
// Options configures a new Client.
@@ -66,6 +69,8 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
}
// validateCredentials checks that exactly one credential type is provided
@@ -134,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -161,26 +167,40 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
if c.connect != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
@@ -197,6 +217,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
c.connect = client
c.cancel = cancel
return nil
}
@@ -211,17 +232,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1)
connect := c.connect
go func() {
done <- c.connect.Stop()
done <- connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
c.connect = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
c.connect = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -315,6 +342,62 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -386,11 +386,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is 0.0.0.0
if ip.IsUnspecified() {
matchByIP = false
}
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {

View File

@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
// Egress rules: match outgoing loopback UDP packets
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
return fmt.Errorf("add output sport notrack rule: %w", err)
}
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
return fmt.Errorf("add output dport notrack rule: %w", err)
}
// Ingress rules: match incoming loopback UDP packets
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
}
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
}
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("create chain: %w", err)
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add prerouting jump rule: %w", err)
}
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
return nil
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("clear and delete chain: %w", err)
}
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex int = -1, -1
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -168,6 +168,10 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -48,8 +49,10 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
}
// Create nftables firewall manager
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.Flush()
if err := m.aclManager.Flush(); err != nil {
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
}
// AddDNATRule adds a DNAT rule
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
return fmt.Errorf("notrack chains not initialized")
}
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
loopback := []byte{127, 0, 0, 1}
// Egress rules: match outgoing loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
// Ingress rules: match incoming loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
&expr.Counter{},
&expr.Notrack{},
},
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush notrack rules: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRaw,
})
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawPrerouting,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRaw,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush chain creation: %w", err)
}
return nil
}
func (m *Manager) refreshNoTrackChains() error {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
tableName := getTableName()
for _, c := range chains {
if c.Table.Name != tableName {
continue
}
switch c.Name {
case chainNameRawOutput:
m.notrackOutputChain = c
case chainNameRawPrerouting:
m.notrackPreroutingChain = c
}
}
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex int = -1, -1
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
@@ -208,11 +208,13 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
if lookup.SetName == "accept-http" {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
} else if lookup.SetName == "deny-http" {
case "deny-http":
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
@@ -222,9 +224,10 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
if verdict.Kind == expr.VerdictAccept {
switch verdict.Kind {
case expr.VerdictAccept:
action = "ACCEPT"
} else if verdict.Kind == expr.VerdictDrop {
case expr.VerdictDrop:
action = "DROP"
}
}
@@ -386,6 +389,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
const octet2Count = 25
const octet3Count = 255
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
for i := 1; i < octet2Count; i++ {
for j := 1; j < octet3Count; j++ {
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")

View File

@@ -48,9 +48,11 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
var subEnd int
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}

View File

@@ -29,7 +29,7 @@ import (
)
const (
layerTypeAll = 0
layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
@@ -262,10 +262,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering(
@@ -439,19 +436,7 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort
r.dPort = dPort
switch proto {
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
r.protoLayer = protoToLayer(proto, r.ipLayer)
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
@@ -496,16 +481,17 @@ func (m *Manager) addRouteFiltering(
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
@@ -584,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -795,7 +789,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
var sum = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -945,7 +939,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked {
_, pnum := getProtocolFromPacket(d)
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -1010,20 +1004,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false
}
proto, pnum := getProtocolFromPacket(d)
protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
Protocol: proto,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
@@ -1052,16 +1048,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return firewall.ProtocolTCP, nftypes.TCP
return nftypes.TCP
case layers.LayerTypeUDP:
return firewall.ProtocolUDP, nftypes.UDP
return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP, nftypes.ICMP
return nftypes.ICMP
default:
return firewall.ProtocolALL, nftypes.ProtocolUnknown
return nftypes.ProtocolUnknown
}
}
@@ -1233,19 +1246,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
return nil, false
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
@@ -1264,21 +1288,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break
}
}
if !sourceMatched {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
}
}
}

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
})
}
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
}
})
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -2,6 +2,7 @@ package forwarder
import (
"fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -16,7 +17,7 @@ type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
mtu atomic.Uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
}
func (e *endpoint) MTU() uint32 {
return e.mtu
return e.mtu.Load()
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}
func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}
func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}
type epID stack.TransportEndpointID
func (i epID) String() string {

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -35,14 +36,16 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
endpoint.mtu.Store(uint32(mtu))
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err)
@@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
}
receiveWindow := defaultReceiveWindow
@@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
@@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort,
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -2,8 +2,11 @@ package forwarder
import (
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"time"
"github.com/google/uuid"
@@ -14,30 +17,95 @@ import (
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
// For Echo Requests, send and wait for response
if icmpHdr.Type() == header.ICMPv4Echo {
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
if !f.hasRawICMPAccess {
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
return true
}
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return true
}
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
}
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
}
defer func() {
if err := conn.Close(); err != nil {
@@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
response := make([]byte, f.endpoint.mtu)
response := make([]byte, f.endpoint.mtu.Load())
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
@@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
return 0
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
return f.injectICMPReply(id, response[:n])
}
// sendICMPEvent stores flow events for ICMP packets
@@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
f.flowLogger.StoreEvent(fields)
}
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
@@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() {
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
return false
}
id := r.ID()
@@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return
return true
}
flowID := uuid.New()
@@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
return false
}
// Create wait queue for blocking syscalls
@@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
return false
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
inConn := gonet.NewUDPConn(&wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
@@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
return true
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
@@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
return true
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
@@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
}
func isTimeout(err error) bool {

View File

@@ -130,6 +130,7 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
_ = mapManager.localIPs[ip.String()]
}
})
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
_ = mapManager.localIPs[ip.String()]
}
})
}

View File

@@ -168,6 +168,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
}
}
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {

View File

@@ -234,9 +234,10 @@ func TestInboundPortDNATNegative(t *testing.T) {
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
switch tc.protocol {
case layers.IPProtocolTCP:
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
case layers.IPProtocolUDP:
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol
protoLayer gopacket.LayerType
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
strId := string(id)
if id == nil {

View File

@@ -0,0 +1,169 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
// to the appropriate connection based on the destination address.
// ReadFrom is not used in the hot path - ICEBind receives packets via
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
readFromWarn sync.Once
}
// NewDualStackPacketConn creates a new dual-stack packet connection.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from the available connection (preferring IPv4).
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
d.readFromWarn.Do(func() {
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
})
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections.
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection if available,
// otherwise the IPv6 connection.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,119 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = conn.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
ds := NewDualStackPacketConn(conn, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn.Close()
ds := NewDualStackPacketConn(nil, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
addrs := []net.Addr{ipv4Addr, ipv6Addr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, addrs[i&1])
}
}

View File

@@ -0,0 +1,191 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
ipv4Conn := &mockPacketConn{network: "udp4"}
ipv6Conn := &mockPacketConn{network: "udp6"}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
tests := []struct {
name string
addr *net.UDPAddr
wantSocket string
}{
{
name: "IPv4 address",
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 address",
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantSocket: "udp6",
},
{
name: "IPv4-mapped IPv6 goes to IPv4",
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv4 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
wantSocket: "udp6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipv4Conn.writeCount = 0
ipv6Conn.writeCount = 0
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
require.NoError(t, err)
assert.Equal(t, 4, n)
if tt.wantSocket == "udp4" {
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
} else {
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
}
})
}
}
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
// IPv4 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.NoError(t, err)
// IPv6 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv6 connection")
}
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
// IPv6 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.NoError(t, err)
// IPv4 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
ipv4Conn := &mockPacketConn{
network: "udp4",
readData: []byte("from ipv4"),
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
}
ipv6Conn := &mockPacketConn{
network: "udp6",
readData: []byte("from ipv6"),
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
buf := make([]byte, 100)
n, addr, err := dualStack.ReadFrom(buf)
require.NoError(t, err)
// reads from IPv4 (preferred) - this is expected behavior
assert.Equal(t, "from ipv4", string(buf[:n]))
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
}
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
tests := []struct {
name string
ipv4 net.PacketConn
ipv6 net.PacketConn
wantAddr net.Addr
}{
{
name: "both available returns IPv4",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv4Addr,
},
{
name: "IPv4 only",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: nil,
wantAddr: ipv4Addr,
},
{
name: "IPv6 only",
ipv4: nil,
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv6Addr,
},
{
name: "neither returns nil",
ipv4: nil,
ipv6: nil,
wantAddr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
})
}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,7 +14,6 @@ import (
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -27,8 +26,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -58,6 +57,8 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -103,6 +104,12 @@ func (s *ICEBind) Close() error {
close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close()
}
@@ -160,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
// Detect IPv4 vs IPv6 from connection's local address
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
s.ipv4Conn = conn
} else {
s.ipv6Conn = conn
}
s.createOrUpdateMux()
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -180,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
//nolint:staticcheck
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
@@ -207,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
continue
}
sizes[i] = msg.N
@@ -233,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
}
// createOrUpdateMux creates or updates the UDP mux with the available connections.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
var muxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
muxConn = NewDualStackPacketConn(
nbnet.WrapPacketConn(s.ipv4Conn),
nbnet.WrapPacketConn(s.ipv6Conn),
)
case s.ipv4Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
case s.ipv6Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
default:
return
}
// Don't close the old mux - it doesn't own the underlying connections.
// The sockets are managed by WireGuard's StdNetBind, not by us.
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
@@ -245,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
s.muUDPMux.Lock()
mux := s.udpMux
s.muUDPMux.Unlock()
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
}
buffers[i] = []byte{}

View File

@@ -0,0 +1,324 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, ipv6Conn := createDualStackConns(t)
defer ipv4Conn.Close()
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
pool := createMsgPool()
// Simulate wireguard-go calling CreateReceiverFn for IPv4
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
require.NotNil(t, ipv4RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
iceBind.muUDPMux.Unlock()
// Simulate wireguard-go calling CreateReceiverFn for IPv6
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
require.NotNil(t, ipv6RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer ipv4Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn)
assert.Nil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.Nil(t, iceBind.ipv4Conn)
assert.NotNil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
// with peers on different address families through the same DualStackPacketConn.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
// two "remote peers" listening on different address families
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
// our local dual-stack connection
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
// send to both peers
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
require.NoError(t, err)
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
require.NoError(t, err)
// verify IPv4 peer got its packet from the IPv4 socket
buf := make([]byte, 100)
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := ipv4Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv4", string(buf[:n]))
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
// verify IPv6 peer got its packet from the IPv6 socket
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err = ipv6Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv6", string(buf[:n]))
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
const packetsPerFamily = 500
ipv4Received := make(chan string, packetsPerFamily)
ipv6Received := make(chan string, packetsPerFamily)
startGate := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv4Peer.ReadFrom(buf)
if err != nil {
return
}
ipv4Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv6Peer.ReadFrom(buf)
if err != nil {
return
}
ipv6Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
}
}()
close(startGate)
time.AfterFunc(5*time.Second, func() {
_ = ipv4Peer.SetReadDeadline(time.Now())
_ = ipv6Peer.SetReadDeadline(time.Now())
})
wg.Wait()
close(ipv4Received)
close(ipv6Received)
ipv4Count := 0
for pkt := range ipv4Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
ipv4Count++
}
ipv6Count := 0
for pkt := range ipv6Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
tests := []struct {
name string
network string
addr string
wantIPv4 bool
}{
{"IPv4 any", "udp4", "0.0.0.0:0", true},
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
{"IPv6 any", "udp6", "[::]:0", false},
{"IPv6 loopback", "udp6", "[::1]:0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
require.NoError(t, err)
conn, err := net.ListenUDP(tt.network, addr)
if err != nil {
t.Skipf("%s not available: %v", tt.network, err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
isIPv4 := localAddr.IP.To4() != nil
assert.Equal(t, tt.wantIPv4, isIPv4)
})
}
}
// helpers
func setupICEBind(t *testing.T) *ICEBind {
t.Helper()
transportNet, err := stdnet.NewNet()
require.NoError(t, err)
address := wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
t.Helper()
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
ipv4Conn.Close()
t.Skipf("IPv6 not available: %v", err)
}
return ipv4Conn, ipv6Conn
}
func createMsgPool() *sync.Pool {
return &sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, 1)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, 40)
}
return &msgs
},
}
}
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
t.Helper()
udpAddr, err := net.ResolveUDPAddr(network, addr)
require.NoError(t, err)
conn, err := net.ListenUDP(network, udpAddr)
require.NoError(t, err)
return conn
}

View File

@@ -3,8 +3,22 @@ package configurer
import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
return wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
PublicKey: peerKey,
PresharedKey: &psk,
UpdateOnly: updateOnly,
}},
}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.configure(cfg)
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
PresharedKey: [32]byte(p.PresharedKey),
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint

View File

@@ -22,17 +22,16 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.device.IpcSet(toWgUserspaceString(cfg))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
}
return sb.String()
}
@@ -543,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
host, portStr, err := net.SplitHostPort(val)
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
if pskKey, err := hexToWireguardKey(val); err == nil {
currentPeer.PresharedKey = [32]byte(pskKey)
}
}
}
}

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
PresharedKey [32]byte
}
type Stats struct {

View File

@@ -1,9 +1,7 @@
//go:build ios
// +build ios
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -45,10 +43,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
// This typically means the Swift code couldn't find the utun control socket.
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
var tunDevice tun.Device
var err error
// Validate the tunnel file descriptor.
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
// A value of 0 means the Swift code couldn't find the utun control socket
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
// tvOS SDK headers). This is a hard error - there's no viable fallback
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
if t.tunFd == 0 {
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
return nil, ErrInvalidTunnelFD
}
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
var dupTunFd int
dupTunFd, err = unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
@@ -60,7 +79,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd)
return nil, err
}
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)

View File

@@ -17,6 +17,7 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)

View File

@@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error
}
@@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
@@ -297,6 +304,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
// SetPresharedKey sets or updates the preshared key for a peer.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)

View File

@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.wgCurrentUsed = ep
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
@@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}

View File

@@ -8,8 +8,6 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -26,13 +24,10 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
proxyPort int
mtu uint16
ebpfManager ebpfMgr.Manager
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
lastUsedPort uint16
rawConn net.PacketConn
rawConnIPv4 net.PacketConn
rawConnIPv6 net.PacketConn
conn transport.UDPConn
ctx context.Context
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
proxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
if err != nil {
return err
}
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
// Prepare IPv6 raw socket (optional)
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
if err != nil {
return err
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
}
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err
}
addr := net.UDPAddr{
Port: wgPorxyPort,
Port: proxyPort,
IP: net.ParseIP(loopbackAddr),
}
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
log.Infof("local wg proxy listening on: %d", proxyPort)
return nil
}
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err)
}
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
if p.rawConnIPv4 != nil {
if err := p.rawConnIPv4.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -216,34 +241,3 @@ generatePort:
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool
pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
return err
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted {
p.isStarted = true
@@ -91,10 +188,32 @@ func (p *ProxyWrapper) Pause() {
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -136,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock()
if err != nil {
@@ -162,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -3,12 +3,19 @@
package wgproxy
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
const (
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
)
type KernelFactory struct {
wgPort int
mtu uint16
@@ -22,6 +29,12 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
mtu: mtu,
}
if isEBPFDisabled() {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
return f
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
@@ -41,9 +54,30 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
func (w *KernelFactory) GetProxyPort() uint16 {
if w.ebpfProxy == nil {
return 0
}
return w.ebpfProxy.GetProxyPort()
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}
func isEBPFDisabled() bool {
val := os.Getenv(envDisableEBPFWGProxy)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
return false
}
return disabled
}

View File

@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
func (w *USPFactory) GetProxyPort() uint16 {
return 0
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -8,43 +8,87 @@ import (
"os"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET, true)
}
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET6, false)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
if isIPv4 {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil
}

View File

@@ -0,0 +1,353 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
func TestRedirectAs_UDP_IPv6(t *testing.T) {
wgPort := 51853
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
t.Helper()
ctx := context.Background()
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
}
defer wgListener4.Close()
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
}
defer wgListener6.Close()
// Determine which listener to use based on the NetBird address IP version
// (this is where initial traffic will come from before RedirectAs is called)
var wgListener *net.UDPConn
if p2pEndpoint.IP.To4() == nil {
wgListener = wgListener6
} else {
wgListener = wgListener4
}
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // Random port
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
// Add TURN connection to proxy
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
// Start the proxy
proxy.Work()
// Phase 1: Test initial relay traffic
msgFromRelay := []byte("hello from relay")
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write to relay server: %v", err)
}
// Set read deadline to avoid hanging
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
buf := make([]byte, 1024)
n, _, err := wgListener4.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read from WireGuard listener: %v", err)
}
if n != len(msgFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
}
if string(buf[:n]) != string(msgFromRelay) {
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
}
// Phase 2: Redirect to p2p endpoint
proxy.RedirectAs(p2pEndpoint)
// Give the proxy a moment to process the redirect
time.Sleep(100 * time.Millisecond)
// Phase 3: Test redirected traffic
redirectedMessages := [][]byte{
[]byte("redirected message 1"),
[]byte("redirected message 2"),
[]byte("redirected message 3"),
}
for i, msg := range redirectedMessages {
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
}
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
}
// Verify message content
if string(buf[:n]) != string(msg) {
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
}
// Verify source address matches p2p endpoint (this is the key test)
// Use compareUDPAddr to ignore IPv6 zone IDs
if !compareUDPAddr(srcAddr, p2pEndpoint) {
t.Errorf("message %d: expected source address %s, got %s",
i+1, p2pEndpoint.String(), srcAddr.String())
}
}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
ctx := context.Background()
// Create WireGuard listener
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create WireGuard listener: %v", err)
}
defer wgListener.Close()
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
proxy.Work()
// Test switching between multiple endpoints - using addresses in local subnet
endpoints := []*net.UDPAddr{
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
}
for i, endpoint := range endpoints {
proxy.RedirectAs(endpoint)
time.Sleep(100 * time.Millisecond)
msg := []byte("test message")
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
}
buf := make([]byte, 1024)
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
}
if !compareUDPAddr(srcAddr, endpoint) {
t.Errorf("endpoint %d: expected source %s, got %s",
i, endpoint.String(), srcAddr.String())
}
}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {

View File

@@ -19,37 +19,56 @@ var (
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
localHostNetIPAddrV4 = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
// Create only the raw socket for the address family we need
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
}
return f, nil
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
err := udpH.SetNetworkLayerForChecksum(networkLayer)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}

View File

@@ -0,0 +1,499 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
mutex sync.RWMutex
client *mgm.GrpcClient
config *profilemanager.Config
privateKey wgtypes.Key
mgmURL *url.URL
mgmTLSEnabled bool
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The connection is automatically recreated with backoff if it becomes disconnected
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
privateKey: myPrivateKey,
mgmURL: mgmURL,
mgmTLSEnabled: mgmTLSEnabled,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
a.mutex.Lock()
defer a.mutex.Unlock()
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
var supportsSSO bool
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
// Try PKCE flow first
_, err := a.getPKCEFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if PKCE is not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if Device flow is also not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// Neither PKCE nor Device flow is supported
supportsSSO = false
return nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return err
})
return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
// This avoids creating a new connection to the management server
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
var flow OAuthFlow
var err error
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
if forceDeviceAuth {
flow, err = a.getDeviceFlow(client)
return err
}
// Try PKCE flow first
flow, err = a.getPKCEFlow(client)
if err != nil {
// If PKCE not supported, try Device flow
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
flow, err = a.getDeviceFlow(client)
return err
}
return err
}
return nil
})
return flow, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
}
needsLogin = false
return err
})
return needsLogin, err
}
// Login attempts to log in or register the client with the management server
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return err, false
}
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
} else if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
isAuthError = false
return nil
})
return err, isAuthError
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
// setSystemInfoFlags sets all configuration flags on the provided system info
func (a *Auth) setSystemInfoFlags(info *system.Info) {
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
}
// reconnect closes the current connection and creates a new one
// It checks if the brokenClient is still the current client before reconnecting
// to avoid multiple threads reconnecting unnecessarily
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
a.mutex.Lock()
defer a.mutex.Unlock()
// Double-check: if client has already been replaced by another thread, skip reconnection
if a.client != brokenClient {
log.Debugf("client already reconnected by another thread, skipping")
return nil
}
// Create new connection FIRST, before closing the old one
// This ensures a.client is never nil, preventing panics in other threads
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
// Keep the old client if reconnection fails
return err
}
// Close old connection AFTER new one is successfully created
oldClient := a.client
a.client = mgmClient
if oldClient != nil {
if err := oldClient.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
return nil
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
// Capture the client BEFORE the operation to ensure we track the correct client
a.mutex.RLock()
currentClient := a.client
a.mutex.RUnlock()
if currentClient == nil {
return status.Errorf(codes.Unavailable, "client is not initialized")
}
// Execute operation with the captured client
err := operation(currentClient)
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection using the client that was actually used
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors, don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// isPermissionDenied checks if the error is a PermissionDenied error.
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
func isPermissionDenied(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.PermissionDenied
}
func isLoginNeeded(err error) bool {
return isAuthenticationError(err)
}
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -26,12 +25,56 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,8 +12,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
return pkceFlowInfo, nil
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
return deviceFlowInfo, nil
}

View File

@@ -20,7 +20,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -35,17 +34,67 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
providerConfig PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,8 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -24,10 +24,14 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -39,11 +43,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -52,19 +58,20 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
}
// RunOnAndroid with main logic on mobile system
@@ -85,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) RunOniOS(
@@ -103,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -162,6 +169,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
var path string
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
// On mobile, use the provided state file path directly
if !fileExists(mobileDependency.StateFilePath) {
if err := createFile(mobileDependency.StateFilePath); err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
} else {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -249,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -273,7 +307,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -283,6 +317,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -376,6 +419,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
@@ -415,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -450,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
}
if config.PreSharedKey != "" {

View File

@@ -27,8 +27,11 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -56,6 +59,8 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +114,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -218,10 +226,11 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logFile string
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
@@ -230,7 +239,6 @@ type BundleGenerator struct {
type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
@@ -239,7 +247,9 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogFile string
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -255,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logFile: deps.LogFile,
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
@@ -304,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err)
}
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err)
}
@@ -327,6 +331,14 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addCPUProfile(); err != nil {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -343,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
@@ -354,6 +366,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -388,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
}
func (g *BundleGenerator) addStatus() error {
if status := g.clientStatus; status != "" {
statusReader := strings.NewReader(status)
if g.statusRecorder != nil {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
if g.refreshStatus != nil {
g.refreshStatus()
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
}
return nil
}
@@ -522,6 +557,31 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addCPUProfile() error {
if len(g.cpuProfile) == 0 {
return nil
}
reader := bytes.NewReader(g.cpuProfile)
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
return fmt.Errorf("add CPU profile to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
@@ -630,6 +690,29 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()
@@ -662,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
}
func (g *BundleGenerator) addLogfile() error {
if g.logFile == "" {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")
return nil
}
logDir := filepath.Dir(g.logFile)
logDir := filepath.Dir(g.logPath)
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}

View File

@@ -507,15 +507,13 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
if p.Base == expr.PayloadBaseNetworkHeader {
switch p.Offset {
case 12:
if p.Len == 4 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
switch p.Len {
case 4, 2:
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
case 16:
if p.Len == 4 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
switch p.Len {
case 4, 2:
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
}

View File

@@ -0,0 +1,101 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -1,4 +1,4 @@
package server
package debug
import (
"context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err)
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err)
id := getURLHash(testURL)
require.Contains(t, key, id+"/")

View File

@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
if peer.PresharedKey != [32]byte{} {
sb.WriteString(" preshared key: (hidden)\n")
}
}

View File

@@ -1,136 +0,0 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
if zone.NonAuthoritative {
continue
}
for _, record := range zone.Records {

View File

@@ -3,17 +3,21 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100
@@ -43,7 +47,23 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -71,6 +92,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock()
defer c.mu.Unlock()
log.Warnf("[DNS-ROUTE] HandlerChain.AddHandler: pattern=%s priority=%d handler=%T", pattern, priority, handler)
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
@@ -87,6 +110,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
matchSubdomains = matcher.MatchSubdomains()
}
log.Warnf("[DNS-ROUTE] HandlerChain.AddHandler: processed pattern=%s origPattern=%s wildcard=%v matchSubdomains=%v priority=%d",
pattern, origPattern, isWildcard, matchSubdomains, priority)
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
pattern, origPattern, isWildcard, matchSubdomains, priority)
@@ -101,6 +126,9 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
log.Warnf("[DNS-ROUTE] HandlerChain.AddHandler: total handlers now=%d", len(c.handlers))
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -140,68 +168,114 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := strings.ToLower(r.Question[0].Name)
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry)
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
if !c.isHandlerMatch(qname, entry) {
continue
}
// Only log for DNS route handlers to reduce noise
if entry.Priority == PriorityDNSRoute {
log.Warnf("[DNS-ROUTE] HandlerChain.ServeDNS: matched DNS route handler pattern=%s for domain=%s",
entry.OrigPattern, qname)
}
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
}
continue
}
c.logResponse(logger, chainWriter, qname, startTime)
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,8 +9,10 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
}
}
}
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -1,30 +1,52 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -48,60 +81,150 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
logger.Debug("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
}
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
if !exists && supportsWildcard(qType) {
testWild := transformDomainToWildcard(string(domainName))
_, exists = d.domains[domain.Domain(testWild)]
}
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
usingWildcard := false
wildQuestion := transformToWildcard(question)
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
// If the domain exists with any record type, return NODATA instead of wildcard match.
if !found && supportsWildcard(question.Qtype) {
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
records, found = d.records[wildQuestion]
usingWildcard = found
}
}
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
}
return nil
return lookupResult{rcode: dns.RcodeNameError}
}
recordsCopy := slices.Clone(records)
@@ -110,29 +233,229 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
q := question
if usingWildcard {
q = wildQuestion
}
records = d.records[q]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
d.records[q] = records
}
d.mu.Unlock()
}
return recordsCopy
if usingWildcard {
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
func transformToWildcard(question dns.Question) dns.Question {
wildQuestion := question
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
return wildQuestion
}
func transformDomainToWildcard(domain string) string {
s := strings.Split(domain, ".")
s[0] = "*"
return strings.Join(s, ".")
}
func supportsWildcard(queryType uint16) bool {
return queryType != dns.TypeNS && queryType != dns.TypeSOA
}
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
records := make([]dns.RR, len(wildRecords))
for i, record := range wildRecords {
copiedRecord := dns.Copy(record)
copiedRecord.Header().Name = originalName
records[i] = copiedRecord
}
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
wildQuestion := transformToWildcard(cnameQuestion)
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
}
}
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -26,6 +27,11 @@ type Resolver struct {
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
return err
}
var aRecords, aaaaRecords []dns.RR
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {

View File

@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -80,6 +80,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -207,6 +208,7 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -218,6 +220,8 @@ func newDefaultServer(
// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
log.Warnf("[DNS-ROUTE] DefaultServer.RegisterHandler: domains=%v priority=%d", domains.SafeString(), priority)
s.mux.Lock()
defer s.mux.Unlock()
@@ -228,10 +232,12 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
log.Warnf("[DNS-ROUTE] DefaultServer.RegisterHandler: extraDomains now has %d entries", len(s.extraDomains))
s.applyHostConfig()
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Warnf("[DNS-ROUTE] registerHandler: domains=%v priority=%d handler=%T", domains, priority, handler)
log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
for _, domain := range domains {
@@ -240,6 +246,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
continue
}
log.Warnf("[DNS-ROUTE] registerHandler: adding to handlerChain domain=%s priority=%d", domain, priority)
s.handlerChain.AddHandler(domain, handler, priority)
}
}
@@ -483,7 +490,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -496,8 +503,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords)
s.localResolver.Update(localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -562,6 +568,7 @@ func (s *DefaultServer) enableDNS() error {
func (s *DefaultServer) applyHostConfig() {
// prevent reapplying config if we're shutting down
if s.ctx.Err() != nil {
log.Warnf("[DNS-ROUTE] applyHostConfig: skipped, context is done")
return
}
@@ -584,16 +591,42 @@ func (s *DefaultServer) applyHostConfig() {
}
}
log.Warnf("[DNS-ROUTE] applyHostConfig: routeAll=%v domains=%d extraDomains=%v",
config.RouteAll, len(config.Domains), maps.Keys(s.extraDomains))
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("[DNS-ROUTE] applyHostConfig: hash error, applying anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Warnf("[DNS-ROUTE] applyHostConfig: skipped, no changes (hash=%d)", hash)
return
}
log.Warnf("[DNS-ROUTE] applyHostConfig: applying new config (oldHash=%d newHash=%d)", s.currentConfigHash, hash)
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Warnf("[DNS-ROUTE] applyHostConfig: failed to apply: %v", err)
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
log.Warnf("[DNS-ROUTE] applyHostConfig: successfully applied")
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers
// registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -602,6 +635,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}
@@ -609,9 +643,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -636,9 +668,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
var zones []nbdns.CustomZone
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -652,17 +684,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, localRecords, nil
return muxUpdates, zones, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
@@ -718,9 +753,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -901,9 +934,7 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -128,7 +133,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -180,8 +185,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -299,8 +304,8 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -315,8 +320,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +390,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +515,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -1602,7 +1606,10 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
},
{
name: "Config update with new domains after registration",
@@ -1657,7 +1664,10 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
},
{
name: "Register domain that is part of nameserver group",
@@ -2042,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -19,8 +18,10 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -70,6 +71,11 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
@@ -113,10 +119,10 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
u.prepareRequest(r)
@@ -125,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if u.tryUpstreamServers(w, r, logger) {
return
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
if !ok {
u.writeErrorResponse(w, r, logger)
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
@@ -138,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -151,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
}
}
return false
return false, failures
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
@@ -173,53 +185,79 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
}()
if err != nil {
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
return u.handleUpstreamError(err, upstream, startTime)
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
return &upstreamFailure{upstream: upstream, reason: "no response"}
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
return &upstreamFailure{upstream: upstream, reason: err.Error()}
}
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
reason += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warn(timeoutMsg)
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
failedCount := len(failures)
failureSummary := formatFailures(failures)
if succeeded {
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
} else {
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
}
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
}
}
func formatFailures(failures []upstreamFailure) string {
parts := make([]string, 0, len(failures))
for _, f := range failures {
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
}
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -414,14 +452,53 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
return nil, err
}
return hex.EncodeToString(bytes)
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts

Some files were not shown because too many files have changed in this diff Show More