Compare commits

...

161 Commits

Author SHA1 Message Date
Diego Noguês
b016a1f0d0 feat: poc for token command on combined 2026-02-13 01:22:59 +01:00
Diego Noguês
c009055693 feat: adds netbird's proxy component to getting-started 2026-02-13 00:42:59 +01:00
Diego Noguês
14181c909c fix: remove duplicate import 2026-02-13 00:02:50 +01:00
mlsmaycon
a05dc3823d Merge branch 'main' into prototype/reverse-proxy
# Conflicts:
#	infrastructure_files/getting-started.sh
2026-02-12 19:27:12 +01:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
Diego Noguês
7d19bdf085 feat: adding traefik + nb's reverse proxy (#5303)
* feat: adding traefik and proxy component to getting-started

* feat: adding traefik and proxy component to getting-started

* feat: adding IPAM settings to docker compose and setting static ip to traefik

* fix: remove change to peers group all

* feat: switch to labels for traefik instead of static conf files

* feat: adding traefik and proxy component to getting-started

* feat: adding IPAM settings to docker compose and setting static ip to traefik

* fix: remove change to peers group all

* feat: switch to labels for traefik instead of static conf files

* chore: remove unnecessary comment

* chore: build

* chore: switching env var for NB_PROXY_DOMAIN
2026-02-12 19:12:20 +01:00
Diego Noguês
a1b048f2ad feat: adding traefik + nb reverse proxy 2026-02-12 18:43:35 +01:00
mlsmaycon
0bd227196e fix integration tests 2026-02-12 18:22:41 +01:00
Viktor Liu
eea7687ddf Fix lint and failing tests 2026-02-12 18:19:13 +01:00
mlsmaycon
57d3ee5aac optimize the DeriveClusterFromDomain function
1. validate domain only for proxy urls
2. use registered target cluster for custom domain extraction
2026-02-12 17:10:32 +01:00
pascal
cfdfdecc14 return error if unable to derive cluster on service creation 2026-02-12 16:57:16 +01:00
mlsmaycon
ac995bae6d rename url flag to domain and update validation 2026-02-12 16:28:29 +01:00
Alisdair MacLeod
41a5509ce0 fix nil pointer error in roundtripper 2026-02-12 15:19:19 +00:00
pascal
db5e26db94 rename domain type 2026-02-12 16:15:02 +01:00
Viktor Liu
fe975fb834 Fix missing lang attribute 2026-02-12 23:03:50 +08:00
Viktor Liu
e368d2995b Fix test 2026-02-12 22:57:28 +08:00
Viktor Liu
a3241d8376 Fix swallowed response codes 2026-02-12 22:54:17 +08:00
Alisdair MacLeod
6dfc5772ba fix nil pointer error in roundtripper 2026-02-12 14:44:07 +00:00
Viktor Liu
f70925178c Handle TCP port reuse for TIME-WAIT connections 2026-02-12 22:06:29 +08:00
Viktor Liu
9554934b92 Validate trusted proxies in OAuth callback getClientIP 2026-02-12 22:06:29 +08:00
Viktor Liu
7fdb824a37 Remove write permissions from /var/lib/netbird in proxy Dockerfile 2026-02-12 22:06:29 +08:00
Viktor Liu
412407adc0 Add .dockerignore to exclude sensitive files from build context 2026-02-12 22:06:29 +08:00
Viktor Liu
e0874d7de7 Add noopener to window.open in ErrorPage 2026-02-12 22:06:29 +08:00
pascal
8df1536cbb Merge branch 'main' into prototype/reverse-proxy 2026-02-12 15:05:14 +01:00
pascal
fcbacc62ec clear userID from access logs if not oidc 2026-02-12 14:50:35 +01:00
pascal
ee2ae45653 add permissions validation to domain manager 2026-02-12 14:31:23 +01:00
pascal
6f2f0f9ae4 exclude proxy peers on peers api 2026-02-12 13:49:05 +01:00
Alisdair MacLeod
c37ebc6fb3 add more metrics, improve metrics, reduce metrics impact on other packages 2026-02-12 12:36:54 +00:00
Viktor Liu
23abb5743c Treated tombstoned conns as new 2026-02-12 20:11:12 +08:00
Viktor Liu
b87aa0bc15 Address linter issues 2026-02-12 18:41:20 +08:00
Maycon Santos
69d4b5d821 [misc] Update sign pipeline version (#5296) 2026-02-12 11:31:49 +01:00
Viktor Liu
f1a65d732d Add proxy to license boundary check 2026-02-12 18:31:18 +08:00
Viktor Liu
a3c0ea3e71 Add proxy unit test workflow 2026-02-12 18:31:18 +08:00
Viktor Liu
abaf061c2a Skip nil client for health 2026-02-12 18:31:18 +08:00
pascal
e531fb54b1 ignore error 2026-02-12 11:20:22 +01:00
mlsmaycon
5fcfed5b16 add proxy tests 2026-02-12 11:19:10 +01:00
pascal
5f43449f67 move linter exceptions 2026-02-12 10:45:21 +01:00
mlsmaycon
6796601aa6 Generate a random nonce to ensure each OIDC request gets a unique state 2026-02-12 10:45:13 +01:00
pascal
1fc25c301b move linter exceptions 2026-02-12 10:11:49 +01:00
Viktor Liu
08ae281b2d Fix network monitor restarting the client in netstack mode 2026-02-12 16:48:31 +08:00
Viktor Liu
3dfa97dcbd [client] Fix stale entries in nftables with no handle (#5272) 2026-02-12 09:15:57 +01:00
Viktor Liu
1ddc9ce2bf [client] Fix nil pointer panic in device and engine code (#5287) 2026-02-12 09:15:42 +01:00
Viktor Liu
bd47f44c63 Preload services targets 2026-02-12 16:04:55 +08:00
Viktor Liu
381260911b Create unique token per proxy 2026-02-12 15:48:35 +08:00
Viktor Liu
38db42e7d6 Fix initial sync complete on empty service list 2026-02-12 15:48:35 +08:00
Viktor Liu
5d606d909d Add TTL-based expiry and cleanup for PKCE verifiers to prevent unbounded memory growth 2026-02-12 15:12:41 +08:00
Viktor Liu
d689718b50 Improve logging and error handling 2026-02-12 15:12:41 +08:00
pascal
54a73c6649 move linter exceptions 2026-02-12 02:10:00 +01:00
pascal
418377842e fix tests 2026-02-12 02:00:22 +01:00
pascal
15ef56e03d fix typos 2026-02-12 01:54:14 +01:00
pascal
917035f8e8 fix tests 2026-02-12 01:52:30 +01:00
pascal
963e3f5457 fix linter issues 2026-02-12 01:15:36 +01:00
pascal
e20b969188 fix linter issues 2026-02-12 01:02:13 +01:00
pascal
1c7059ee67 fix some tests 2026-02-12 00:16:33 +01:00
pascal
22a3365658 fix rename errors and tests 2026-02-11 22:34:50 +01:00
Maycon Santos
2de1949018 [client] Check if login is required on foreground mode (#5295) 2026-02-11 21:42:36 +01:00
pascal
08ab1e3478 rename reverse proxy to services 2026-02-11 21:39:51 +01:00
pascal
ebb1f4007d add id to request log search 2026-02-11 19:25:23 +01:00
pascal
acb53ece93 Merge branch 'prototype/reverse-proxy-logs-pagination' into prototype/reverse-proxy 2026-02-11 18:51:28 +01:00
pascal
e020950cfd concat host and path for search and add a status filter 2026-02-11 17:54:29 +01:00
pascal
9dba262a20 add index to access log entries 2026-02-11 17:07:15 +01:00
pascal
5bcdf36377 fix source_ip 2026-02-11 16:50:27 +01:00
pascal
1ffe8deb10 add general search filter 2026-02-11 16:38:31 +01:00
pascal
d069145bd1 add more filters 2026-02-11 16:23:52 +01:00
Alisdair MacLeod
f3493ee042 add basic metrics for stress testing 2026-02-11 14:56:39 +00:00
pascal
bf48044e5c push filter files 2026-02-11 14:52:44 +01:00
pascal
fb4cc37a4a add pagination for access logs 2026-02-11 14:41:52 +01:00
pascal
55b8d89a79 add rate limiting for callback endpoint 2026-02-11 13:42:54 +01:00
pascal
6968a32a5a move to argon2id 2026-02-11 13:26:40 +01:00
pascal
cfe6753349 hash pin and password 2026-02-11 11:48:15 +01:00
Alisdair MacLeod
5ae15b3af3 add hotpath proxy and roundtripper benchmarks 2026-02-11 09:47:40 +00:00
pascal
b79adb706c add services to permissions list 2026-02-11 10:38:20 +01:00
mlsmaycon
f22497d5da remove query parameters on refresh 2026-02-10 21:53:18 +01:00
mlsmaycon
95d672c9df fix: capture auth method in access logs for failed authentication
- Add wasCredentialSubmitted helper to detect when credentials were
  submitted but authentication failed
- Set auth method in CapturedData when wrong PIN/password is entered
- Set auth method for OAuth callback errors and token validation errors
- Add tests for failed auth method capture
2026-02-10 21:33:15 +01:00
mlsmaycon
7d08a609e6 fix: capture account/service/user IDs in access logs for auth requests
- Add accountID and serviceID to auth middleware DomainConfig
- Set account/service IDs in CapturedData when domain is matched
- Update AddDomain to accept accountID and serviceID parameters
- Skip access logging for internal proxy assets (/__netbird__/*)
- Return validationResult struct from validateSessionToken to preserve
  user ID even when access is denied
- Capture user ID and auth method in access logs for denied requests
2026-02-10 20:55:07 +01:00
mlsmaycon
eea6120cd0 refactor: add ValidateSession gRPC and streamline test setup
- Add ValidateSession gRPC method for proxy-side user validation
- Move group access validation from REST callback to gRPC layer
- Capture user info in access logs via CapturedData mutable pointer
- Create validate_session_test.go for gRPC validation tests
- Simplify auth_callback_integration_test.go to create accounts
  programmatically instead of using SQL file
- SQL test data file now only used by validate_session_test.go
2026-02-10 20:31:03 +01:00
Vlad
fc88399c23 [management] fixed ischild check (#5279) 2026-02-10 20:31:15 +03:00
pascal
0cb02bd906 fix path handling + extract targets to separate table + guard resource/peer deletion 2026-02-10 17:12:34 +01:00
mlsmaycon
08d3867f41 update error page 2026-02-10 16:54:05 +01:00
mlsmaycon
b16d63643c Add group-based access control for SSO reverse proxy authentication
Implement user group validation during OAuth callback to ensure users
belong to allowed distribution groups before granting access to reverse
proxies. This provides account isolation and fine-grained access control.

Key changes:
- Add ValidateUserGroupAccess to ProxyServiceServer for group membership checks
- Redirect denied users to error page with access_denied parameter
- Handle OAuth error responses in proxy middleware
- Add comprehensive integration tests for auth callback flow
2026-02-10 16:25:00 +01:00
Eduard Gert
940d01bdea Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-10 14:39:48 +01:00
Eduard Gert
ba9158d159 Remove peer card from proxy error page 2026-02-10 14:39:25 +01:00
pascal
ca9a7e11ef continue on host lookup failure 2026-02-10 14:38:15 +01:00
pascal
a803f47685 add network map support for clustering 2026-02-10 14:29:20 +01:00
Viktor Liu
79fed32f01 Add wg port configuration 2026-02-10 19:55:48 +08:00
Viktor Liu
6b00bb0a66 Strip session_token on redirect 2026-02-10 18:27:31 +08:00
mlsmaycon
e2adef1eea add back notBefore and now to cert log 2026-02-09 20:37:20 +01:00
pascal
9e5fa11792 handle multiple path 2026-02-09 19:25:30 +01:00
pascal
1ff75acb31 handle default ports 2026-02-09 19:23:39 +01:00
pascal
1754160686 handle default ports 2026-02-09 19:21:43 +01:00
pascal
423f6266fb handle default ports 2026-02-09 18:18:53 +01:00
pascal
16d1b4a14a handle default ports 2026-02-09 18:15:26 +01:00
pascal
7c14056faf fix resource lookup 2026-02-09 17:58:28 +01:00
pascal
62e37dc2e2 fix host resolution 2026-02-09 17:56:38 +01:00
pascal
6a08695ee8 Merge branch 'main' into prototype/reverse-proxy 2026-02-09 17:16:00 +01:00
pascal
9a67a8e427 send updates on changes 2026-02-09 17:06:04 +01:00
Viktor Liu
73aa0785ba Add cert health info to checks 2026-02-09 22:55:12 +08:00
Viktor Liu
53c1016a8e Add graceful shutdown for Kubernetes 2026-02-09 22:55:12 +08:00
Viktor Liu
fd442138e6 Add cert hot reload and cert file locking
Adds file-watching certificate hot reload, cross-replica ACME
certificate lock coordination via flock (Unix) and Kubernetes lease
objects.
2026-02-09 22:55:12 +08:00
pascal
be5f30225a fix embedded exception 2026-02-09 15:28:48 +01:00
pascal
7467e9fb8c use portrange 2026-02-09 14:46:23 +01:00
pascal
2390c2e46e change network map calc to inject proxy policies 2026-02-09 14:41:22 +01:00
Zoltan Papp
6981fdce7e [client] Fix race condition and ensure correct message ordering in Relay (#5265)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.
2026-02-09 11:34:24 +01:00
Viktor Liu
08403f64aa [client] Add env var to skip DNS probing (#5270) 2026-02-09 11:09:11 +01:00
Viktor Liu
391221a986 [client] Fix uspfilter duplicate firewall rules (#5269) 2026-02-09 10:14:02 +01:00
mlsmaycon
778c223176 fix api handler path 2026-02-09 02:30:06 +01:00
mlsmaycon
36cd0dd85c temp fix import cycle 2026-02-09 02:10:21 +01:00
mlsmaycon
09a1d5a02d rename endpoint 2026-02-09 01:48:51 +01:00
mlsmaycon
7c996ac9b5 add AuthCallbackURL 2026-02-09 01:18:49 +01:00
mlsmaycon
cf9fd5d960 add AuthClientID 2026-02-08 19:41:52 +01:00
mlsmaycon
1c5ab7cb8f add logger support to acme manager 2026-02-08 19:11:27 +01:00
Viktor Liu
aaad3b25a7 Increase client startup timeout
The client has to start mgmt, signal, relay and wireguard/netstack.
If this times out, the client shuts down and never manages to start.
2026-02-09 02:02:18 +08:00
Viktor Liu
9904235a2f Improve embed client error detection and reporting 2026-02-09 01:51:53 +08:00
Viktor Liu
780e9f57a5 Improve mgmt backoff 2026-02-09 01:51:53 +08:00
mlsmaycon
a8db73285b add issued time log and CT timestamp logs 2026-02-08 18:13:50 +01:00
Viktor Liu
3b43c00d12 Use unique static path for auth assets to avoid collision with routes 2026-02-09 01:10:50 +08:00
Viktor Liu
2f390e1794 Conflate default ports 2026-02-09 00:57:08 +08:00
Viktor Liu
3630ebb3ae Add option to rewrite redirects 2026-02-09 00:44:47 +08:00
Viktor Liu
260c46df04 Fix broken auth redirect 2026-02-09 00:02:54 +08:00
Viktor Liu
7f11e3205d Validate target id 2026-02-08 23:44:31 +08:00
Viktor Liu
1c8f92a96f Fix management nil pointer 2026-02-08 23:29:16 +08:00
Viktor Liu
7b6294b624 Refuse to service a service if auth setup failed 2026-02-08 23:24:43 +08:00
Viktor Liu
156d0b1fef Fix duplicate path 2026-02-08 21:41:32 +08:00
Viktor Liu
2cf00dba58 Fix missing route 2026-02-08 21:36:55 +08:00
Viktor Liu
d2a7f3ae36 Fix pass host header 2026-02-08 21:33:48 +08:00
Viktor Liu
6a64d4e4dd Remove test deployment specs 2026-02-08 21:13:22 +08:00
Viktor Liu
51e63c246b Add health status to debug 2026-02-08 21:04:46 +08:00
mlsmaycon
99e6b1eda4 attempt to trigger ssl before first request
1. When AddDomain() is called (when proxy receives a new mapping), it now spawns a goroutine to prefetch the certificate
  2. prefetchCertificate() creates a synthetic tls.ClientHelloInfo and calls GetCertificate() to trigger the ACME flow
  3. The certificate is cached by autocert.DirCache, so subsequent real requests will use the cached cert
  4. If the cert is already cached (e.g., proxy restart), GetCertificate just returns it without making ACME requests
2026-02-08 10:59:36 +01:00
Viktor Liu
dc26a5a436 Merge branch 'main' into prototype/reverse-proxy 2026-02-08 17:50:16 +08:00
Viktor Liu
3883b2fb41 Fix netbird_test.go 2026-02-08 17:49:03 +08:00
Viktor Liu
ed58659a01 Set forwarded headers from trusted proxies only 2026-02-08 17:49:03 +08:00
Viktor Liu
5190923c70 Improve logging requests 2026-02-08 17:49:03 +08:00
Viktor Liu
7c647dd160 Add peer firewall to the receiving peer 2026-02-08 17:49:03 +08:00
Viktor Liu
07e59b2708 Add reverse proxy header security and forwarding
- Rewrite Host header to backend target (configurable via pass_host_header per mapping)
- Strip and set X-Forwarded-For/X-Real-IP from direct connection (trust boundary)
- Set X-Forwarded-Host and X-Forwarded-Proto headers
- Strip nb_session cookie and session_token query param before forwarding
- Add --forwarded-proto flag (auto/http/https) for proto detection
- Fix OIDC redirect hardcoded https scheme
- Add pass_host_header to proto, API, and management model
2026-02-08 15:00:35 +08:00
Viktor Liu
0a3a9f977d Add proxy <-> management authentication 2026-02-08 14:33:27 +08:00
mlsmaycon
2f263bf7e6 fix cluster logic for domains and reverse proxy 2026-02-07 11:43:01 +01:00
mlsmaycon
f65f4fc280 fix some conflicts regression 2026-02-06 20:39:17 +01:00
Zoltan Papp
7bc85107eb Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) 2026-02-06 19:50:48 +01:00
Zoltan Papp
3be16d19a0 [management] Feature/grpc debounce msgtype (#5239)
* Add gRPC update debouncing mechanism

Implements backpressure handling for peer network map updates to
efficiently handle rapid changes. First update is sent immediately,
subsequent rapid updates are coalesced, ensuring only the latest
update is sent after a 1-second quiet period.

* Enhance unit test to verify peer count synchronization with debouncing and timeout handling

* Debounce based on type

* Refactor test to validate timer restart after pending update dispatch

* Simplify timer reset for Go 1.23+ automatic channel draining

Remove manual channel drain in resetTimer() since Go 1.23+ automatically
drains the timer channel when Stop() returns false, making the
select-case pattern unnecessary.
2026-02-06 19:47:38 +01:00
Vlad
af8f730bda [management] check stream start time for connecting peer (#5267) 2026-02-06 18:00:43 +01:00
pascal
adbd7ab4c3 send account updates on proxy change 2026-02-06 17:03:18 +01:00
pascal
0419834482 add routed exposed services support in nmap 2026-02-06 15:42:13 +01:00
eyJhb
c3f176f348 [client] Fix wrong URL being logged for DefaultAdminURL (#5252)
- DefaultManagementURL was being logged instead of DefaultAdminURL
2026-02-06 11:23:36 +01:00
Viktor Liu
0119f3e9f4 [client] Fix netstack detection and add wireguard port option (#5251)
- Add WireguardPort option to embed.Options for custom port configuration
- Fix KernelInterface detection to account for netstack mode
- Skip SSH config updates when running in netstack mode
- Skip interface removal wait when running in netstack mode
- Use BindListener for netstack to avoid port conflicts on same host
2026-02-06 10:03:01 +01:00
pascal
f797d2d9cb fix cert dir name in docker file 2026-02-05 15:46:07 +01:00
pascal
5ae7efe8f7 Merge remote-tracking branch 'origin/prototype/reverse-proxy' into prototype/reverse-proxy 2026-02-05 15:22:39 +01:00
pascal
d6e35bd0fe fix merge conflicts 2026-02-05 15:22:23 +01:00
pascal
0e00f1c8f7 Merge remote-tracking branch 'origin/prototype/reverse-proxy-clusters' into prototype/reverse-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/manager/manager.go
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/modules.go
#	management/internals/shared/grpc/proxy.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-02-05 15:19:57 +01:00
Viktor Liu
1b96648d4d [client] Always log dns forwader responses (#5262) 2026-02-05 14:34:35 +01:00
Zoltan Papp
d2f9653cea Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261)
Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic
when agent field is nil. This can occur during Windows suspend/resume
when network interfaces are disrupted or the pion/ice library returns
nil without error.

Also capture agent pointer in local variable before goroutine execution
to prevent race conditions.

Fixes service crashes on laptop wake-up.
2026-02-05 12:06:28 +01:00
mlsmaycon
5ccce1ab3f add debug logging for proxy connections and domain resolution
- Log proxy address and cluster info when proxy connects
  - Log connected proxy URLs when GetConnectedProxyURLs is called
  - Log proxy allow list when GetDomains is called
  - Helps debug issues with free domains not appearing in API response
2026-02-05 02:18:38 +01:00
Zoltan Papp
194a986926 Cache the result of wgInterface.ToInterface() using sync.Once (#5256)
Avoid repeated conversions during route setup. The toInterface helper ensures
the conversion happens only once regardless of how many routes are added
or removed.
2026-02-04 22:22:37 +01:00
Viktor Liu
f7732557fa [client] Add missing bsd flags in debug bundle (#5254) 2026-02-04 18:07:27 +01:00
Vlad
d488f58311 [management] fix set disconnected status for connected peer (#5247) 2026-02-04 11:44:46 +01:00
mlsmaycon
b02982f6b1 add logs 2026-02-04 03:14:26 +01:00
mlsmaycon
4d89ae27ef add clusters logic 2026-02-04 02:16:57 +01:00
Pascal Fischer
6fdc00ff41 [management] adding account id validation to accessible peers handler (#5246) 2026-02-03 17:30:02 +01:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
Vlad
8931293343 [management] run cancelPeerRoutinesWithoutLock in sync (#5234) 2026-02-01 15:44:27 +01:00
Vlad
7b830d8f72 disable sync lim (#5233) 2026-02-01 14:37:00 +01:00
Misha Bragin
3a0cf230a1 Disable local users for a smooth single-idp mode (#5226)
Add LocalAuthDisabled option to embedded IdP configuration

This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external
identity providers (Google, OIDC, etc.).

This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page.

Key changes:

Added LocalAuthDisabled field to EmbeddedIdPConfig
Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth
Validation prevents disabling local auth if no external connectors are configured
Existing local users are preserved when disabled and can login again when re-enabled
Operations are idempotent (disabling already disabled is a no-op)
2026-02-01 14:26:22 +01:00
227 changed files with 20391 additions and 3183 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)

View File

@@ -46,6 +46,5 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -144,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -204,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
'
test_relay:
@@ -261,6 +261,53 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -20,7 +20,7 @@ jobs:
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -520,6 +540,55 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -598,6 +667,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +756,19 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features

View File

@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
}
defer authClient.Close()
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("check login required: %v", err)
}
jwtToken := ""

View File

@@ -31,6 +31,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string

View File

@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("list rules: %w", err)
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
newRules[string(rule.UserData)] = rule
}
}
}
return nil
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -719,3 +720,137 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -3,12 +3,6 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,12 +1,9 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if !isWindowsFirewallReachable() {
return nil

View File

@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
if exists && !conn.IsSupersededBy(flags) {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsTombstone() {
if !exists || conn.IsSupersededBy(flags) {
return false
}

View File

@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -12,11 +13,13 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -24,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -89,6 +93,7 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleID := uuid.New().String()
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleID := rule.ID()
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {

View File

@@ -0,0 +1,376 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -5,6 +5,8 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -16,9 +18,18 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
defaultLogChanSize = 1000
)
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32
const (
@@ -69,7 +80,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
msgChannel: make(chan logMessage, getLogChannelSize()),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -29,8 +29,9 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return nsTunDev, tunNet, nil
return t.tundev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,7 +6,9 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -27,6 +29,8 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)

View File

@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
if len(query.Question) == 0 {
return nil
return
}
question := query.Question[0]
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
qname := strings.ToLower(question.Name)
domain := strings.ToLower(question.Name)
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
f.writeResponse(logger, w, resp, qname, startTime)
return
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
return resp
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
f.handleDNSQuery(logger, w, query, startTime)
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
}
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
assert.Empty(t, resp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
@@ -543,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -828,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1918,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
log.Infof("Network monitor is disabled, not starting")
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -38,11 +37,6 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

View File

@@ -2,6 +2,7 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
sessionID, err := NewICESessionID()

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
log.Infof("using default Admin URL %s", DefaultAdminURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
},
)

View File

@@ -4,16 +4,17 @@ package systemops
import (
"strings"
"syscall"
"golang.org/x/sys/unix"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
if routeMessageFlags&unix.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
return true
}
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&syscall.RTF_UP != 0 {
if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&syscall.RTF_GATEWAY != 0 {
if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&syscall.RTF_HOST != 0 {
if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&syscall.RTF_REJECT != 0 {
if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&syscall.RTF_DYNAMIC != 0 {
if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&syscall.RTF_MODIFIED != 0 {
if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&syscall.RTF_STATIC != 0 {
if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&syscall.RTF_LLINFO != 0 {
if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&syscall.RTF_LOCAL != 0 {
if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&syscall.RTF_BLACKHOLE != 0 {
if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&syscall.RTF_CLONING != 0 {
if flags&unix.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&syscall.RTF_WASCLONED != 0 {
if flags&unix.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,17 +4,18 @@ package systemops
import (
"strings"
"syscall"
"golang.org/x/sys/unix"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
if routeMessageFlags&unix.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
return true
}
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&syscall.RTF_UP != 0 {
if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&syscall.RTF_GATEWAY != 0 {
if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&syscall.RTF_HOST != 0 {
if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&syscall.RTF_REJECT != 0 {
if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&syscall.RTF_DYNAMIC != 0 {
if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&syscall.RTF_MODIFIED != 0 {
if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&syscall.RTF_STATIC != 0 {
if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&syscall.RTF_LLINFO != 0 {
if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&syscall.RTF_LOCAL != 0 {
if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&syscall.RTF_BLACKHOLE != 0 {
if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

5
combined/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

715
combined/cmd/config.go Normal file
View File

@@ -0,0 +1,715 @@
package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
// CombinedConfig is the root configuration for the combined server.
// The combined server is primarily a Management server with optional embedded
// Signal, Relay, and STUN services.
//
// Architecture:
// - Management: Always runs locally (this IS the management server)
// - Signal: Runs locally by default; disabled if server.signalUri is set
// - Relay: Runs locally by default; disabled if server.relays is set
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
//
// All user-facing settings are under "server". The relay/signal/management
// fields are internal and populated automatically from server settings.
type CombinedConfig struct {
Server ServerConfig `yaml:"server"`
// Internal configs - populated from Server settings, not user-configurable
Relay RelayConfig `yaml:"-"`
Signal SignalConfig `yaml:"-"`
Management ManagementConfig `yaml:"-"`
}
// ServerConfig contains server-wide settings
// In simplified mode, this contains all configuration
type ServerConfig struct {
ListenAddress string `yaml:"listenAddress"`
MetricsPort int `yaml:"metricsPort"`
HealthcheckAddress string `yaml:"healthcheckAddress"`
LogLevel string `yaml:"logLevel"`
LogFile string `yaml:"logFile"`
TLS TLSConfig `yaml:"tls"`
// Simplified config fields (used when relay/signal/management sections are omitted)
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
DataDir string `yaml:"dataDir"` // Data directory for all services
// External service overrides (simplified mode)
// When these are set, the corresponding local service is NOT started
// and these values are used for client configuration instead
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
// Management settings (simplified mode)
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// TLSConfig contains TLS/HTTPS settings
type TLSConfig struct {
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
}
// LetsEncryptConfig contains Let's Encrypt settings
type LetsEncryptConfig struct {
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"dataDir"`
Domains []string `yaml:"domains"`
Email string `yaml:"email"`
AWSRoute53 bool `yaml:"awsRoute53"`
}
// RelayConfig contains relay service settings
type RelayConfig struct {
Enabled bool `yaml:"enabled"`
ExposedAddress string `yaml:"exposedAddress"`
AuthSecret string `yaml:"authSecret"`
LogLevel string `yaml:"logLevel"`
Stun StunConfig `yaml:"stun"`
}
// StunConfig contains embedded STUN service settings
type StunConfig struct {
Enabled bool `yaml:"enabled"`
Ports []int `yaml:"ports"`
LogLevel string `yaml:"logLevel"`
}
// SignalConfig contains signal service settings
type SignalConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
}
// ManagementConfig contains management service settings
type ManagementConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
DataDir string `yaml:"dataDir"`
DnsDomain string `yaml:"dnsDomain"`
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
Auth AuthConfig `yaml:"auth"`
Stuns []HostConfig `yaml:"stuns"`
Relays RelaysConfig `yaml:"relays"`
SignalURI string `yaml:"signalUri"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
}
// AuthStorageConfig contains auth storage settings
type AuthStorageConfig struct {
Type string `yaml:"type"`
File string `yaml:"file"`
}
// AuthOwnerConfig contains initial admin user settings
type AuthOwnerConfig struct {
Email string `yaml:"email"`
Password string `yaml:"password"`
}
// HostConfig represents a STUN/TURN/Signal host
type HostConfig struct {
URI string `yaml:"uri"`
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// RelaysConfig contains external relay server settings for clients
type RelaysConfig struct {
Addresses []string `yaml:"addresses"`
CredentialsTTL string `yaml:"credentialsTTL"`
Secret string `yaml:"secret"`
}
// StoreConfig contains database settings
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
func DefaultConfig() *CombinedConfig {
return &CombinedConfig{
Server: ServerConfig{
ListenAddress: ":443",
MetricsPort: 9090,
HealthcheckAddress: ":9000",
LogLevel: "info",
LogFile: "console",
StunPorts: []int{3478},
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Store: StoreConfig{
Engine: "sqlite",
},
},
Relay: RelayConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
Stun: StunConfig{
Enabled: false,
Ports: []int{3478},
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
},
Signal: SignalConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
Management: ManagementConfig{
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Relays: RelaysConfig{
CredentialsTTL: "12h",
},
Store: StoreConfig{
Engine: "sqlite",
},
},
}
}
// hasRequiredSettings returns true if the configuration has the required server settings
func (c *CombinedConfig) hasRequiredSettings() bool {
return c.Server.ExposedAddress != ""
}
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
// Returns: protocol ("https" or "http"), hostname only, and host:port
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
// Default to https if no protocol specified
protocol = "https"
hostPort = exposedAddress
// Check for protocol prefix
if strings.HasPrefix(exposedAddress, "https://") {
protocol = "https"
hostPort = strings.TrimPrefix(exposedAddress, "https://")
} else if strings.HasPrefix(exposedAddress, "http://") {
protocol = "http"
hostPort = strings.TrimPrefix(exposedAddress, "http://")
}
// Extract hostname (without port)
hostname = hostPort
if host, _, err := net.SplitHostPort(hostPort); err == nil {
hostname = host
}
return protocol, hostname, hostPort
}
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
// overrides are configured (server.signalUri, server.relays, server.stuns).
func (c *CombinedConfig) ApplySimplifiedDefaults() {
if !c.hasRequiredSettings() {
return
}
// Parse exposed address to extract protocol and hostname
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
// Check for external service overrides
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
hasExternalSignal := c.Server.SignalURI != ""
hasExternalStuns := len(c.Server.Stuns) > 0
// Default stunPorts to [3478] if not specified and no external STUN
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
c.Server.StunPorts = []int{3478}
}
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
c.applySignalDefaults(hasExternalSignal)
c.applyManagementDefaults(exposedHost)
// Auto-configure client settings (stuns, relays, signalUri)
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
return
}
c.Relay.Enabled = true
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
c.Relay.AuthSecret = c.Server.AuthSecret
if c.Relay.LogLevel == "" {
c.Relay.LogLevel = c.Server.LogLevel
}
// Enable local STUN only if no external STUN servers and stunPorts are configured
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
c.Relay.Stun.Enabled = true
c.Relay.Stun.Ports = c.Server.StunPorts
if c.Relay.Stun.LogLevel == "" {
c.Relay.Stun.LogLevel = c.Server.LogLevel
}
}
}
// applySignalDefaults configures the signal service if no external signal is configured.
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
if hasExternalSignal {
return
}
c.Signal.Enabled = true
if c.Signal.LogLevel == "" {
c.Signal.LogLevel = c.Server.LogLevel
}
}
// applyManagementDefaults configures the management service (always enabled).
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
c.Management.Enabled = true
if c.Management.LogLevel == "" {
c.Management.LogLevel = c.Server.LogLevel
}
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
c.Management.DnsDomain = exposedHost
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
// Copy auth config from server if management auth issuer is not set
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
c.Management.Auth = c.Server.Auth
}
// Copy store config from server if not set
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" {
c.Management.Store = c.Server.Store
}
}
// Copy reverse proxy config from server
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
c.Management.ReverseProxy = c.Server.ReverseProxy
}
}
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
// External overrides from server config take precedence over auto-generated values
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
// Determine relay protocol from exposed protocol
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
// Configure STUN servers for clients
if hasExternalStuns {
// Use external STUN servers from server config
c.Management.Stuns = c.Server.Stuns
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
})
}
}
// Configure relay for clients
if hasExternalRelay {
// Use external relay config from server
c.Management.Relays = c.Server.Relays
} else if len(c.Management.Relays.Addresses) == 0 {
// Auto-configure local relay
c.Management.Relays.Addresses = []string{
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
}
}
if c.Management.Relays.Secret == "" {
c.Management.Relays.Secret = c.Server.AuthSecret
}
if c.Management.Relays.CredentialsTTL == "" {
c.Management.Relays.CredentialsTTL = "12h"
}
// Configure signal for clients
if hasExternalSignal {
// Use external signal URI from server config
c.Management.SignalURI = c.Server.SignalURI
} else if c.Management.SignalURI == "" {
// Auto-configure local signal
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
}
}
// LoadConfig loads configuration from a YAML file
func LoadConfig(configPath string) (*CombinedConfig, error) {
cfg := DefaultConfig()
if configPath == "" {
return cfg, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Populate internal configs from server settings
cfg.ApplySimplifiedDefaults()
return cfg, nil
}
// Validate validates the configuration
func (c *CombinedConfig) Validate() error {
if c.Server.ExposedAddress == "" {
return fmt.Errorf("server.exposedAddress is required")
}
if c.Server.DataDir == "" {
return fmt.Errorf("server.dataDir is required")
}
// Validate STUN ports
seen := make(map[int]bool)
for _, port := range c.Server.StunPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
}
seen[port] = true
}
// authSecret is required only if running local relay (no external relay configured)
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
if !hasExternalRelay && c.Server.AuthSecret == "" {
return fmt.Errorf("server.authSecret is required when running local relay")
}
return nil
}
// HasTLSCert returns true if TLS certificate files are configured
func (c *CombinedConfig) HasTLSCert() bool {
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
}
// HasLetsEncrypt returns true if Let's Encrypt is configured
func (c *CombinedConfig) HasLetsEncrypt() bool {
return c.Server.TLS.LetsEncrypt.Enabled &&
c.Server.TLS.LetsEncrypt.DataDir != "" &&
len(c.Server.TLS.LetsEncrypt.Domains) > 0
}
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
switch strings.ToLower(proto) {
case "udp":
return nbconfig.UDP, true
case "dtls":
return nbconfig.DTLS, true
case "tcp":
return nbconfig.TCP, true
case "http":
return nbconfig.HTTP, true
case "https":
return nbconfig.HTTPS, true
default:
return "", false
}
}
// parseStunProtocol determines protocol for STUN/TURN servers.
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
// Explicit proto overrides URI scheme. Defaults to UDP.
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
if proto != "" {
if p, ok := parseExplicitProtocol(proto); ok {
return p
}
}
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "stuns:"):
return nbconfig.DTLS
case strings.HasPrefix(uri, "turns:"):
return nbconfig.DTLS
default:
// stun:, turn:, or no scheme - default to UDP
return nbconfig.UDP
}
}
// parseSignalProtocol determines protocol for Signal servers.
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
func parseSignalProtocol(uri string) nbconfig.Protocol {
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "http://"):
return nbconfig.HTTP
default:
// https:// or no scheme - default to HTTPS
return nbconfig.HTTPS
}
}
// stripSignalProtocol removes the protocol prefix from a signal URI.
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
func stripSignalProtocol(uri string) string {
uri = strings.TrimPrefix(uri, "https://")
uri = strings.TrimPrefix(uri, "http://")
return uri
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
// Build STUN hosts
var stuns []*nbconfig.Host
for _, s := range mgmt.Stuns {
stuns = append(stuns, &nbconfig.Host{
URI: s.URI,
Proto: parseStunProtocol(s.URI, s.Proto),
Username: s.Username,
Password: s.Password,
})
}
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
}
// Build signal config
var signalConfig *nbconfig.Host
if mgmt.SignalURI != "" {
signalConfig = &nbconfig.Host{
URI: stripSignalProtocol(mgmt.SignalURI),
Proto: parseSignalProtocol(mgmt.SignalURI),
}
}
// Build store config
storeConfig := nbconfig.StoreConfig{
Engine: types.Engine(mgmt.Store.Engine),
}
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
}
}
for _, p := range mgmt.ReverseProxy.TrustedPeers {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
}
}
// Build HTTP config (required, even if empty)
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP
httpConfig.AuthIssuer = mgmt.Auth.Issuer
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
return &nbconfig.Config{
Stuns: stuns,
Relay: relayConfig,
Signal: signalConfig,
Datadir: mgmt.DataDir,
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
HttpConfig: httpConfig,
StoreConfig: storeConfig,
ReverseProxy: reverseProxy,
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
EmbeddedIdP: embeddedIdP,
}, nil
}
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
// Embedded IdP requires single account mode
if disableSingleAccMode {
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
}
// Set LocalAddress for embedded IdP, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
// EnsureEncryptionKey generates an encryption key if not set.
// Unlike management server, we don't write back to the config file.
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
keyPreview := key[:8] + "..."
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
return nil
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}

33
combined/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

711
combined/cmd/root.go Normal file
View File

@@ -0,0 +1,711 @@
package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/coder/websocket"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
)
var (
configPath string
config *CombinedConfig
rootCmd = &cobra.Command{
Use: "combined",
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
Long: `Combined Netbird server for self-hosted deployments.
All services (Management, Signal, Relay) are multiplexed on a single port.
Optional STUN server runs on separate UDP ports.
Configuration is loaded from a YAML file specified with --config.`,
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, _ []string) error {
if err := initializeConfig(); err != nil {
return err
}
// Management is required as the base server when signal or relay are enabled
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
}
servers, err := createAllServers(cmd.Context(), config)
if err != nil {
return err
}
// Register services with management's gRPC server using AfterInit hook
setupServerHooks(servers, config)
// Start management server (this also starts the HTTP listener)
if servers.mgmtSrv != nil {
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
cleanupSTUNListeners(servers.stunListeners)
return fmt.Errorf("failed to start management server: %w", err)
}
}
// Start all other servers
wg := sync.WaitGroup{}
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
wg.Wait()
return err
}
// initializeConfig loads and validates the configuration, then initializes logging.
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if err := config.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
return nil
}
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
stunListeners []*net.UDPConn
metricsServer *sharedMetrics.Metrics
}
// createAllServers creates all server instances based on configuration.
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
if err != nil {
return nil, fmt.Errorf("failed to create metrics server: %w", err)
}
servers := &serverInstances{
metricsServer: metricsServer,
}
_, tlsSupport, err := handleTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
}
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
return nil, err
}
if err := servers.createManagementServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createSignalServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createHealthcheckServer(cfg); err != nil {
return nil, err
}
return servers, nil
}
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
if !cfg.Relay.Enabled {
return nil
}
var err error
s.stunListeners, err = createSTUNListeners(cfg)
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
relayCfg := relayServer.Config{
Meter: s.metricsServer.Meter,
ExposedAddress: cfg.Relay.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
if err != nil {
return err
}
log.Infof("Relay server created")
if len(s.stunListeners) > 0 {
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
}
return nil
}
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Management.Enabled {
return nil
}
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("failed to create management config: %w", err)
}
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
if portErr != nil {
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
}
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to ensure encryption key: %w", err)
}
LogConfigInfo(mgmtConfig)
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management server: %w", err)
}
// Inject externally-managed AppMetrics so management uses the shared metrics server
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management app metrics: %w", err)
}
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
log.Infof("Management server created")
return nil
}
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Signal.Enabled {
return nil
}
var err error
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create signal server: %w", err)
}
log.Infof("Signal server created")
return nil
}
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
hCfg := healthcheck.Config{
ListenAddress: cfg.Server.HealthcheckAddress,
ServiceChecker: s.relaySrv,
}
var err error
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
return err
}
// setupServerHooks registers services with management's gRPC server.
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
if servers.mgmtSrv == nil {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
if srv != nil {
instanceURL := srv.InstanceURL()
log.Infof("Relay server instance URL: %s", instanceURL.String())
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
}
if mgmtSrv != nil {
log.Infof("shutting down management and signal servers")
if err := mgmtSrv.Stop(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
}
}
if metricsServer != nil {
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
}
return errs
}
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
srv, err := relayServer.NewServer(cfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create relay server: %w", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cfg.Relay.Stun.Enabled {
for _, port := range cfg.Relay.Stun.Ports {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
}
stunListeners = append(stunListeners, listener)
log.Infof("STUN server listening on UDP port %d", port)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
tlsCfg := cfg.Server.TLS
if tlsCfg.LetsEncrypt.AWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: tlsCfg.LetsEncrypt.DataDir,
Email: tlsCfg.LetsEncrypt.Email,
Domains: tlsCfg.LetsEncrypt.Domains,
}
tc, err := r53.GetCertificate()
if err != nil {
return nil, false, err
}
return tc, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt")
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
if err != nil {
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
}
return certManager.TLSConfig(), true, nil
}
if cfg.HasTLSCert() {
log.Debugf("using file based TLS config")
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
if err != nil {
return nil, false, err
}
return tc, true, nil
}
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
// If no port specified, assume default
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
)
return mgmtSrv, nil
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
// Native gRPC traffic (HTTP/2 with gRPC content-type)
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
grpcServer.ServeHTTP(w, r)
// WebSocket proxy for Management gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(w, r)
// WebSocket proxy for Signal gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
if cfg.Signal.Enabled {
wsProxy.Handler().ServeHTTP(w, r)
} else {
http.Error(w, "Signal service not enabled", http.StatusNotFound)
}
// Relay WebSocket
case r.URL.Path == "/relay":
if relayAcceptFn != nil {
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
} else {
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
}
})
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept relay ws connection: %s", err)
return
}
connRemoteAddr := r.RemoteAddr
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}
// logConfig prints all configuration parameters for debugging
func logConfig(cfg *CombinedConfig) {
log.Info("=== Configuration ===")
logServerConfig(cfg)
logComponentsConfig(cfg)
logRelayConfig(cfg)
logManagementConfig(cfg)
log.Info("=== End Configuration ===")
}
func logServerConfig(cfg *CombinedConfig) {
log.Info("--- Server ---")
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
log.Infof(" Log level: %s", cfg.Server.LogLevel)
log.Infof(" Data dir: %s", cfg.Server.DataDir)
switch {
case cfg.HasTLSCert():
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
case cfg.HasLetsEncrypt():
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
default:
log.Info(" TLS: disabled (using reverse proxy)")
}
}
func logComponentsConfig(cfg *CombinedConfig) {
log.Info("--- Components ---")
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
}
func logRelayConfig(cfg *CombinedConfig) {
if !cfg.Relay.Enabled {
return
}
log.Info("--- Relay ---")
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
if cfg.Relay.Stun.Enabled {
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
} else {
log.Info(" STUN: disabled")
}
}
func logManagementConfig(cfg *CombinedConfig) {
if !cfg.Management.Enabled {
return
}
log.Info("--- Management ---")
log.Infof(" Data dir: %s", cfg.Management.DataDir)
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
if cfg.Server.Store.DSN != "" {
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
}
log.Info(" Auth (embedded IdP):")
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
log.Info(" Client settings:")
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
for _, s := range cfg.Management.Stuns {
log.Infof(" STUN: %s", s.URI)
}
if len(cfg.Management.Relays.Addresses) > 0 {
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
}
}
// logEnvVars logs all NB_ environment variables that are currently set
func logEnvVars() {
log.Info("=== Environment Variables ===")
found := false
for _, env := range os.Environ() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
}
log.Infof(" %s=%s", key, value)
found = true
}
}
if !found {
log.Info(" (none set)")
}
log.Info("=== End Environment Variables ===")
}
// maskDSNPassword masks the password in a DSN string.
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
func maskDSNPassword(dsn string) string {
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
if strings.Contains(dsn, "password=") {
parts := strings.Fields(dsn)
for i, p := range parts {
if strings.HasPrefix(p, "password=") {
parts[i] = "password=****"
}
}
return strings.Join(parts, " ")
}
// URI format: "user:password@host..."
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
prefix := dsn[:atIdx]
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
}
}
return dsn
}
// maskSecret returns first 4 chars of secret followed by "..."
func maskSecret(secret string) string {
if len(secret) <= 4 {
return "****"
}
return secret[:4] + "..."
}

219
combined/cmd/token.go Normal file
View File

@@ -0,0 +1,219 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
tokenCmd.AddCommand(tokenCreateCmd, tokenListCmd, tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
// Load combined server YAML config
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
// Get datadir from config or override
datadir := cfg.Server.DataDir
if tokenDatadir != "" {
datadir = tokenDatadir
}
// Get store engine from config
storeEngine := types.Engine(cfg.Server.Store.Engine)
if storeEngine == "" {
storeEngine = "sqlite"
}
s, err := store.NewStore(ctx, storeEngine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -0,0 +1,111 @@
# NetBird Combined Server Configuration
# Copy this file to config.yaml and customize for your deployment
#
# This is a Management server with optional embedded Signal, Relay, and STUN services.
# By default, all services run locally. You can use external services instead by
# setting the corresponding override fields.
#
# Architecture:
# - Management: Always runs locally (this IS the management server)
# - Signal: Local by default; set 'signalUri' to use external (disables local)
# - Relay: Local by default; set 'relays' to use external (disables local)
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Public address that peers will use to connect to this server
# Used for relay connections and management DNS domain
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
exposedAddress: "https://server.mycompany.com:443"
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
# stunPorts:
# - 3478
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Shared secret for relay authentication (required when running local relay)
authSecret: "your-secret-key-here"
# Data directory for all services
dataDir: "/var/lib/netbird/"
# ============================================================================
# External Service Overrides (optional)
# Use these to point to external Signal, Relay, or STUN servers instead of
# running them locally. When set, the corresponding local service is disabled.
# ============================================================================
# External STUN servers - disables local STUN server
# stuns:
# - uri: "stun:stun.example.com:3478"
# - uri: "stun:stun.example.com:3479"
# External relay servers - disables local relay server
# relays:
# addresses:
# - "rels://relay.example.com:443"
# credentialsTTL: "12h"
# secret: "relay-shared-secret"
# External signal server - disables local signal server
# signalUri: "https://signal.example.com:443"
# ============================================================================
# Management Settings
# ============================================================================
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
# Embedded authentication/identity provider (Dex) configuration (always enabled)
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://server.mycompany.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.netbird.io/nb-auth"
- "https://app.netbird.io/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -0,0 +1,115 @@
# Simplified Combined NetBird Server Configuration
# Copy this file to config.yaml and customize for your deployment
# Server-wide settings
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Relay service configuration
relay:
# Enable/disable the relay service
enabled: true
# Public address that peers will use to connect to this relay
# Format: hostname:port or ip:port
exposedAddress: "relay.example.com:443"
# Shared secret for relay authentication (required when enabled)
authSecret: "your-secret-key-here"
# Log level for relay (reserved for future use, currently uses global log level)
logLevel: "info"
# Embedded STUN server (optional)
stun:
enabled: false
ports: [3478]
logLevel: "info"
# Signal service configuration
signal:
# Enable/disable the signal service
enabled: true
# Log level for signal (reserved for future use, currently uses global log level)
logLevel: "info"
# Management service configuration
management:
# Enable/disable the management service
enabled: true
# Data directory for management service
dataDir: "/var/lib/netbird/"
# DNS domain for the management server
dnsDomain: ""
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://management.example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
- "https://app.example.com/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# External STUN servers (for client config)
stuns: []
# - uri: "stun:stun.example.com:3478"
# External relay servers (for client config)
relays:
addresses: []
# - "rels://relay.example.com:443"
credentialsTTL: "12h"
secret: ""
# External signal server URI (for client config)
signalUri: ""
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings
reverseProxy:
trustedHTTPProxies: []
trustedHTTPProxiesCount: 0
trustedPeers: []

13
combined/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/combined/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

3
go.mod
View File

@@ -40,7 +40,6 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
@@ -70,7 +69,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0

6
go.sum
View File

@@ -107,8 +107,6 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
@@ -408,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
return nil
}
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return false, fmt.Errorf("failed to list connectors: %w", err)
}
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
for _, conn := range connectors {
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
if conn.ID != "local" || conn.Type != "local" {
p.logger.Info("found non-local connector", "id", conn.ID)
return true, nil
}
}
p.logger.Info("no non-local connectors found")
return false, nil
}
// DisableLocalAuth removes the local (password) connector.
// Returns an error if no other connectors are configured.
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
hasOthers, err := p.HasNonLocalConnectors(ctx)
if err != nil {
return err
}
if !hasOthers {
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
}
// Check if local connector exists
_, err = p.storage.GetConnector(ctx, "local")
if errors.Is(err, storage.ErrNotFound) {
// Already disabled
return nil
}
if err != nil {
return fmt.Errorf("failed to check local connector: %w", err)
}
// Delete the local connector
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
return fmt.Errorf("failed to delete local connector: %w", err)
}
p.logger.Info("local authentication disabled")
return nil
}
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
return ensureLocalConnector(ctx, p.storage)
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
FROM golang:1.25-bookworm AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -19,6 +19,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -55,7 +57,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
@@ -133,35 +135,35 @@ var (
}
)
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
applyCommandLineOverrides(loadedConfig)
ApplyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
logConfigInfo(loadedConfig)
LogConfigInfo(loadedConfig)
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
// ApplyCommandLineOverrides applies command-line flag overrides to the config
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
@@ -174,9 +176,9 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
}
}
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -213,17 +215,20 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
return nil
}
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
@@ -249,16 +254,16 @@ func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
@@ -285,8 +290,8 @@ func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
return nil
}
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
@@ -299,8 +304,8 @@ func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
@@ -309,8 +314,8 @@ func logConfigInfo(cfg *nbconfig.Config) {
}
}
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}

View File

@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}

View File

@@ -80,4 +80,10 @@ func init() {
migrationCmd.AddCommand(upCmd)
rootCmd.AddCommand(migrationCmd)
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
tokenCmd.AddCommand(tokenCreateCmd)
tokenCmd.AddCommand(tokenListCmd)
tokenCmd.AddCommand(tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}

209
management/cmd/token.go Normal file
View File

@@ -0,0 +1,209 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -0,0 +1,101 @@
package cmd
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseDuration(t *testing.T) {
tests := []struct {
name string
input string
expected time.Duration
wantErr bool
}{
{
name: "empty string returns zero",
input: "",
expected: 0,
},
{
name: "days suffix",
input: "30d",
expected: 30 * 24 * time.Hour,
},
{
name: "one day",
input: "1d",
expected: 24 * time.Hour,
},
{
name: "365 days",
input: "365d",
expected: 365 * 24 * time.Hour,
},
{
name: "hours via Go duration",
input: "24h",
expected: 24 * time.Hour,
},
{
name: "minutes via Go duration",
input: "30m",
expected: 30 * time.Minute,
},
{
name: "complex Go duration",
input: "1h30m",
expected: 90 * time.Minute,
},
{
name: "invalid day format",
input: "abcd",
wantErr: true,
},
{
name: "negative days",
input: "-1d",
wantErr: true,
},
{
name: "zero days",
input: "0d",
wantErr: true,
},
{
name: "non-numeric days",
input: "xyzd",
wantErr: true,
},
{
name: "negative Go duration",
input: "-24h",
wantErr: true,
},
{
name: "zero Go duration",
input: "0s",
wantErr: true,
},
{
name: "invalid Go duration",
input: "notaduration",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseDuration(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -174,14 +174,13 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
exposedServices := account.GetExposedServicesMap()
proxyPeers := account.GetProxyPeers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -234,7 +233,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs, exposedServices, proxyPeers)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -249,7 +248,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
@@ -325,6 +327,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("failed to get validated peers: %v", err)
}
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -355,7 +358,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs, account.GetExposedServicesMap(), account.GetProxyPeers())
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -372,7 +375,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
return nil
}
@@ -437,6 +443,8 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -471,7 +479,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers(), account.GetExposedServicesMap(), account.GetProxyPeers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -780,6 +788,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
},
},
},
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
@@ -842,9 +851,10 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers(), account.GetExposedServicesMap(), account.GetProxyPeers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
},
}}
MessageType: network_map.MessageTypeNetworkMap,
}
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
},
}}
MessageType: network_map.MessageTypeNetworkMap,
}
peersUpdater.SendUpdate(context.Background(), peer, update2)
timeout := time.After(5 * time.Second)

View File

@@ -4,6 +4,19 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
// MessageType indicates the type of update message for debouncing strategy
type MessageType int
const (
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
// These updates can be safely debounced - only the latest state matters
MessageTypeNetworkMap MessageType = iota
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
// These updates should not be dropped as they contain time-sensitive information
MessageTypeControlConfig
)
type UpdateMessage struct {
Update *proto.SyncResponse
Update *proto.SyncResponse
MessageType MessageType
}

View File

@@ -33,7 +33,7 @@ type Manager interface {
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string) error
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
}
type managerImpl struct {
@@ -185,7 +185,7 @@ func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, er
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string) error {
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
if err == nil && existingPeerID != "" {
// Peer already exists
@@ -194,8 +194,11 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
name := fmt.Sprintf("proxy-%s", xid.New().String())
peer := &peer.Peer{
Ephemeral: true,
ProxyEmbedded: true,
Ephemeral: true,
ProxyMeta: peer.ProxyMeta{
Cluster: cluster,
Embedded: true,
},
Name: name,
Key: peerKey,
LoginExpirationEnabled: false,

View File

@@ -162,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -13,42 +13,42 @@ import (
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ProxyID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string
Host string
Path string
Duration time.Duration
StatusCode int
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string
AuthMethodUsed string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
func (a *AccessLogEntry) FromProto(proxyLog *proto.AccessLog) {
a.ID = proxyLog.GetLogId()
a.ProxyID = proxyLog.GetServiceId()
a.Timestamp = proxyLog.GetTimestamp().AsTime()
a.Method = proxyLog.GetMethod()
a.Host = proxyLog.GetHost()
a.Path = proxyLog.GetPath()
a.Duration = time.Duration(proxyLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(proxyLog.GetResponseCode())
a.UserId = proxyLog.GetUserId()
a.AuthMethodUsed = proxyLog.GetAuthMechanism()
a.AccountID = proxyLog.GetAccountId()
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.ID = serviceLog.GetLogId()
a.ServiceID = serviceLog.GetServiceId()
a.Timestamp = serviceLog.GetTimestamp().AsTime()
a.Method = serviceLog.GetMethod()
a.Host = serviceLog.GetHost()
a.Path = serviceLog.GetPath()
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(serviceLog.GetResponseCode())
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
if sourceIP := proxyLog.GetSourceIp(); sourceIP != "" {
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
}
}
if !proxyLog.GetAuthSuccess() {
if !serviceLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if proxyLog.GetResponseCode() >= 400 {
} else if serviceLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
}
}
@@ -88,7 +88,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
return &api.ProxyAccessLog{
Id: a.ID,
ProxyId: a.ProxyID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,

View File

@@ -0,0 +1,124 @@
package accesslogs
import (
"net/http"
"strconv"
"time"
)
const (
// DefaultPageSize is the default number of records per page
DefaultPageSize = 50
// MaxPageSize is the maximum number of records allowed per page
MaxPageSize = 100
)
// AccessLogFilter holds pagination and filtering parameters for access logs
type AccessLogFilter struct {
// Page is the current page number (1-indexed)
Page int
// PageSize is the number of records per page
PageSize int
// Filtering parameters
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
Host *string // Filter by host header
Path *string // Filter by request path (supports LIKE pattern)
UserID *string // Filter by authenticated user ID
UserEmail *string // Filter by user email (requires user lookup)
UserName *string // Filter by user name (requires user lookup)
Method *string // Filter by HTTP method
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
StatusCode *int // Filter by HTTP status code
StartDate *time.Time // Filter by timestamp >= start_date
EndDate *time.Time // Filter by timestamp <= end_date
}
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
queryParams := r.URL.Query()
f.Page = 1
if pageStr := queryParams.Get("page"); pageStr != "" {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
f.Page = page
}
}
f.PageSize = DefaultPageSize
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
f.PageSize = pageSize
if f.PageSize > MaxPageSize {
f.PageSize = MaxPageSize
}
}
}
if search := queryParams.Get("search"); search != "" {
f.Search = &search
}
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
f.SourceIP = &sourceIP
}
if host := queryParams.Get("host"); host != "" {
f.Host = &host
}
if path := queryParams.Get("path"); path != "" {
f.Path = &path
}
if userID := queryParams.Get("user_id"); userID != "" {
f.UserID = &userID
}
if userEmail := queryParams.Get("user_email"); userEmail != "" {
f.UserEmail = &userEmail
}
if userName := queryParams.Get("user_name"); userName != "" {
f.UserName = &userName
}
if method := queryParams.Get("method"); method != "" {
f.Method = &method
}
if status := queryParams.Get("status"); status != "" {
f.Status = &status
}
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
f.StatusCode = &statusCode
}
}
if startDate := queryParams.Get("start_date"); startDate != "" {
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
if err == nil {
f.StartDate = &parsedStartDate
}
}
if endDate := queryParams.Get("end_date"); endDate != "" {
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
if err == nil {
f.EndDate = &parsedEndDate
}
}
}
// GetOffset calculates the database offset for pagination
func (f *AccessLogFilter) GetOffset() int {
return (f.Page - 1) * f.PageSize
}
// GetLimit returns the page size for database queries
func (f *AccessLogFilter) GetLimit() int {
return f.PageSize
}

View File

@@ -0,0 +1,161 @@
package accesslogs
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
tests := []struct {
name string
queryParams map[string]string
expectedPage int
expectedPageSize int
}{
{
name: "default values when no params provided",
queryParams: map[string]string{},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
{
name: "valid page and page_size",
queryParams: map[string]string{
"page": "2",
"page_size": "25",
},
expectedPage: 2,
expectedPageSize: 25,
},
{
name: "page_size exceeds max, should cap at MaxPageSize",
queryParams: map[string]string{
"page": "1",
"page_size": "200",
},
expectedPage: 1,
expectedPageSize: MaxPageSize,
},
{
name: "invalid page number, should use default",
queryParams: map[string]string{
"page": "invalid",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "invalid page_size, should use default",
queryParams: map[string]string{
"page": "2",
"page_size": "invalid",
},
expectedPage: 2,
expectedPageSize: DefaultPageSize,
},
{
name: "zero page number, should use default",
queryParams: map[string]string{
"page": "0",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "negative page number, should use default",
queryParams: map[string]string{
"page": "-1",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "zero page_size, should use default",
queryParams: map[string]string{
"page": "1",
"page_size": "0",
},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
q := req.URL.Query()
for key, value := range tt.queryParams {
q.Set(key, value)
}
req.URL.RawQuery = q.Encode()
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
})
}
}
func TestAccessLogFilter_GetOffset(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
expectedOffset int
}{
{
name: "first page",
page: 1,
pageSize: 50,
expectedOffset: 0,
},
{
name: "second page",
page: 2,
pageSize: 50,
expectedOffset: 50,
},
{
name: "third page with page size 25",
page: 3,
pageSize: 25,
expectedOffset: 50,
},
{
name: "page 10 with page size 10",
page: 10,
pageSize: 10,
expectedOffset: 90,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filter := &AccessLogFilter{
Page: tt.page,
PageSize: tt.pageSize,
}
offset := filter.GetOffset()
assert.Equal(t, tt.expectedOffset, offset)
})
}
}
func TestAccessLogFilter_GetLimit(t *testing.T) {
filter := &AccessLogFilter{
Page: 2,
PageSize: 25,
}
limit := filter.GetLimit()
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
}

View File

@@ -6,5 +6,5 @@ import (
type Manager interface {
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
GetAllAccessLogs(ctx context.Context, accountID, userID string) ([]*AccessLogEntry, error)
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
}

View File

@@ -30,7 +30,10 @@ func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
return
}
logs, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId)
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -41,5 +44,21 @@ func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
apiLogs = append(apiLogs, *log.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiLogs)
response := &api.ProxyAccessLogsResponse{
Data: apiLogs,
Page: filter.Page,
PageSize: filter.PageSize,
TotalRecords: int(totalCount),
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
}
util.WriteJSONObject(r.Context(), w, response)
}
// getTotalPageCount calculates the total number of pages
func getTotalPageCount(totalCount, pageSize int) int {
if pageSize <= 0 {
return 0
}
return (totalCount + pageSize - 1) / pageSize
}

View File

@@ -2,6 +2,7 @@ package manager
import (
"context"
"strings"
log "github.com/sirupsen/logrus"
@@ -43,11 +44,11 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"proxy_id": logEntry.ProxyID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
}).Errorf("failed to save access log: %v", err)
return err
}
@@ -55,20 +56,53 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
return nil
}
// GetAllAccessLogs retrieves all access logs for an account
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string) ([]*accesslogs.AccessLogEntry, error) {
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
return nil, 0, status.NewPermissionDeniedError()
}
logs, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID)
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
}
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
if err != nil {
return nil, err
return nil, 0, err
}
return logs, nil
return logs, totalCount, nil
}
// resolveUserFilters converts user email/name filters to user ID filter
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
if filter.UserEmail == nil && filter.UserName == nil {
return nil
}
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
var matchingUserIDs []string
for _, user := range users {
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
continue
}
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
}
}
if len(matchingUserIDs) > 0 {
filter.UserID = &matchingUserIDs[0]
}
return nil
}

View File

@@ -0,0 +1,17 @@
package domain
type Type string
const (
TypeFree Type = "free"
TypeCustom Type = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
TargetCluster string // The proxy cluster this domain should be validated against
Type Type `gorm:"-"`
Validated bool
}

View File

@@ -0,0 +1,12 @@
package domain
import (
"context"
)
type Manager interface {
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
}

View File

@@ -1,178 +0,0 @@
package domain
import (
"context"
"fmt"
"net"
"net/url"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
)
type domainType string
const (
TypeFree domainType = "free"
TypeCustom domainType = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
Type domainType `gorm:"-"`
Validated bool
}
type store interface {
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *Domain) (*Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
}
type Manager struct {
store store
validator Validator
proxyURLProvider proxyURLProvider
}
func NewManager(store store, proxyURLProvider proxyURLProvider) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
validator: Validator{
resolver: net.DefaultResolver,
},
}
}
func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) {
account, err := m.store.GetAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
free, err := m.store.ListFreeDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list free domains: %w", err)
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
}
var ret []*Domain
// Populate all fields correctly for custom domains that are retrieved.
for _, domain := range domains {
ret = append(ret, &Domain{
ID: domain.ID,
Domain: domain.Domain,
AccountID: accountID,
Type: TypeCustom,
Validated: domain.Validated,
})
}
// Prepend each free domain with the account nonce and then add it to the domain
// array to be returned.
// This account nonce is added to free domains to prevent users being able to
// query free domain usage across accounts and simplifies tracking free domain
// usage across accounts.
for _, name := range free {
ret = append(ret, &Domain{
Domain: account.ReverseProxyFreeDomainNonce + "." + name,
AccountID: accountID,
Type: TypeFree,
Validated: true,
})
}
return ret, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, domainName string) (*Domain, error) {
// Attempt an initial validation; however, a failure is still acceptable for creation
// because the user may not yet have configured their DNS records, or the DNS update
// has not yet reached the servers that are queried by the validation resolver.
var validated bool
if m.validator.IsValid(ctx, domainName, m.proxyURLAllowList()) {
validated = true
}
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, validated)
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
return d, nil
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) error {
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
return nil
}
func (m Manager) ValidateDomain(accountID, domainID string) {
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("get custom domain from store")
return
}
if m.validator.IsValid(context.Background(), d.Domain, m.proxyURLAllowList()) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Debug("domain validated successfully")
d.Validated = true
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).WithError(err).Error("update custom domain in store")
return
}
}
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs (as reported by the proxy servers). It performs some clean
// up on those URLs to attempt to retrieve domain names as we would
// expect to see them in a validation check.
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
var allowedProxyURLs []string
for _, addr := range reverseProxyAddresses {
proxyUrl, err := url.Parse(addr)
if err != nil {
// TODO: log?
continue
}
host, _, err := net.SplitHostPort(proxyUrl.Host)
if err != nil {
// TODO: log?
host = proxyUrl.Host
}
allowedProxyURLs = append(allowedProxyURLs, host)
}
return allowedProxyURLs
}

View File

@@ -1,4 +1,4 @@
package domain
package manager
import (
"encoding/json"
@@ -6,6 +6,7 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -27,11 +28,11 @@ func RegisterEndpoints(router *mux.Router, manager Manager) {
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
switch t {
case TypeCustom:
case domain.TypeCustom:
return api.ReverseProxyDomainTypeCustom
case TypeFree:
case domain.TypeFree:
return api.ReverseProxyDomainTypeFree
}
// By default return as a "free" domain as that is more restrictive.
@@ -39,13 +40,17 @@ func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
return api.ReverseProxyDomainTypeFree
}
func domainToApi(d *Domain) api.ReverseProxyDomain {
return api.ReverseProxyDomain{
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
resp := api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster
}
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
@@ -55,7 +60,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId)
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -82,7 +87,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
return
}
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, req.Domain)
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -104,7 +109,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, domainID); err != nil {
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -125,7 +130,7 @@ func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.R
return
}
go h.manager.ValidateDomain(userAuth.AccountId, domainID)
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
w.WriteHeader(http.StatusAccepted)
}

View File

@@ -0,0 +1,279 @@
package manager
import (
"context"
"fmt"
"net"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
type store interface {
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager,
}
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
}
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
for _, cluster := range allowList {
ret = append(ret, &domain.Domain{
Domain: cluster,
AccountID: accountID,
Type: domain.TypeFree,
Validated: true,
})
}
// Add custom domains.
for _, d := range domains {
ret = append(ret, &domain.Domain{
ID: d.ID,
Domain: d.Domain,
AccountID: accountID,
TargetCluster: d.TargetCluster,
Type: domain.TypeCustom,
Validated: d.Validated,
})
}
return ret, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
clusterValid = true
break
}
}
if !clusterValid {
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
}
// Attempt an initial validation against the specified cluster only
var validated bool
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
validated = true
}
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
return d, nil
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
return nil
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).Info("starting domain validation")
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("get custom domain from store")
return
}
// Validate only against the domain's target cluster
targetCluster := d.TargetCluster
if targetCluster == "" {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Warn("domain has no target cluster set, skipping validation")
return
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Info("validating domain against target cluster")
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Info("domain validated successfully")
d.Validated = true
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).WithError(err).Error("update custom domain in store")
return
}
} else {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Warn("domain validation failed - CNAME does not match target cluster")
}
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
return reverseProxyAddresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
return cluster, nil
}
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return "", fmt.Errorf("list custom domains: %w", err)
}
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
if valid {
return targetCluster, nil
}
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {
return customDomain.TargetCluster, true
}
}
return "", false
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}
return "", false
}

View File

@@ -13,15 +13,15 @@ type resolver interface {
}
type Validator struct {
resolver resolver
Resolver resolver
}
// NewValidator initializes a validator with a specific DNS resolver.
// If a Validator is used without specifying a resolver, then it will
// NewValidator initializes a validator with a specific DNS Resolver.
// If a Validator is used without specifying a Resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
return &Validator{
resolver: resolver,
Resolver: resolver,
}
}
@@ -32,28 +32,57 @@ func NewValidator(resolver resolver) *Validator {
// The comparison is very simple, so wildcards will not match if included
// in the acceptable domain list.
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
if v.resolver == nil {
v.resolver = net.DefaultResolver
_, valid := v.ValidateWithCluster(ctx, domain, accept)
return valid
}
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
// Returns the cluster address and true if valid, or empty string and false if invalid.
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
if v.Resolver == nil {
v.Resolver = net.DefaultResolver
}
// Prepend subdomain for ownership validation because we want to check
// for the record being a wildcard ("*.example.com"), but you cannot
// look up a wildcard so we have to add a subdomain for the check.
cname, err := v.resolver.LookupCNAME(ctx, "validation."+domain)
lookupDomain := "validation." + domain
log.WithFields(log.Fields{
"domain": domain,
"lookupDomain": lookupDomain,
"acceptList": accept,
}).Debug("looking up CNAME for domain validation")
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
}).WithError(err).Error("Error resolving CNAME from resolver")
return false
"domain": domain,
"lookupDomain": lookupDomain,
}).WithError(err).Warn("CNAME lookup failed for domain validation")
return "", false
}
// Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing).
nakedCNAME := strings.TrimSuffix(cname, ".")
for _, domain := range accept {
// Currently, the match is a very simple string comparison.
if nakedCNAME == strings.TrimSuffix(domain, ".") {
return true
log.WithFields(log.Fields{
"domain": domain,
"cname": cname,
"nakedCNAME": nakedCNAME,
"acceptList": accept,
}).Debug("CNAME lookup result for domain validation")
for _, acceptDomain := range accept {
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
if nakedCNAME == normalizedAccept {
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"cluster": acceptDomain,
}).Info("domain CNAME matched cluster")
return acceptDomain, true
}
}
return false
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"acceptList": accept,
}).Warn("domain CNAME does not match any accepted cluster")
return "", false
}

View File

@@ -1,15 +1,23 @@
package reverseproxy
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
)
type Manager interface {
GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*ReverseProxy, error)
GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*ReverseProxy, error)
CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status ProxyStatus) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}

View File

@@ -0,0 +1,225 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}
// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}
// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}
// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}
// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}
// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
return ret0
}
// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -20,148 +20,148 @@ type handler struct {
manager reverseproxy.Manager
}
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
// Hang domain endpoints off the main router here.
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domain.RegisterEndpoints(domainRouter, domainManager)
domainmanager.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies", h.getAllReverseProxies).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies", h.createReverseProxy).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.getReverseProxy).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.updateReverseProxy).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.deleteReverseProxy).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllReverseProxies(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allReverseProxies, err := h.manager.GetAllReverseProxies(r.Context(), userAuth.AccountId, userAuth.UserId)
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiReverseProxies := make([]*api.ReverseProxy, 0, len(allReverseProxies))
for _, reverseProxy := range allReverseProxies {
apiReverseProxies = append(apiReverseProxies, reverseProxy.ToAPIResponse())
apiServices := make([]*api.Service, 0, len(allServices))
for _, service := range allServices {
apiServices = append(apiServices, service.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiReverseProxies)
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ReverseProxyRequest
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = reverseProxy.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdReverseProxy, err := h.manager.CreateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdReverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
reverseProxy, err := h.manager.GetReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID)
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, reverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
var req api.ReverseProxyRequest
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.ID = reverseProxyID
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
service := new(reverseproxy.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = reverseProxy.Validate(); err != nil {
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedReverseProxy, err := h.manager.UpdateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedReverseProxy.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteReverseProxy(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
if err := h.manager.DeleteReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID); err != nil {
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -17,25 +19,33 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
tokenStore *nbgrpc.OneTimeTokenStore
clusterDeriver ClusterDeriver
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore) reverseproxy.Manager {
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
tokenStore: tokenStore,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -44,10 +54,58 @@ func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userI
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -56,10 +114,19 @@ func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, re
return nil, status.NewPermissionDeniedError()
}
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -68,34 +135,49 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
return nil, status.NewPermissionDeniedError()
}
authConfig := reverseProxy.Auth
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
}
reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, reverseProxy.Targets, reverseProxy.Enabled)
reverseProxy.Auth = authConfig
service.AccountID = accountID
service.ProxyCluster = proxyCluster
service.InitNewRecord()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
// Generate session JWT signing keys
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
reverseProxy.SessionPublicKey = keyPair.PublicKey
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing reverse proxy: %w", err)
return fmt.Errorf("failed to check existing service: %w", err)
}
}
if existingReverseProxy != nil {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
if existingService != nil {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if err = transaction.CreateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("failed to create reverse proxy: %w", err)
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
@@ -104,19 +186,21 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
return nil, err
}
token, err := m.tokenStore.GenerateToken(accountID, reverseProxy.ID, 5*time.Minute)
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyCreated, reverseProxy.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()))
m.accountManager.UpdateAccountPeers(ctx, accountID)
return reverseProxy, nil
return service, nil
}
func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -125,28 +209,69 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
return nil, status.NewPermissionDeniedError()
}
var oldCluster string
var domainChanged bool
var serviceEnabledChanged bool
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Get existing reverse proxy
existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID)
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
// Check if domain changed and if it conflicts
if existingReverseProxy.Domain != reverseProxy.Domain {
conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
oldCluster = existingService.ProxyCluster
if existingService.Domain != service.Domain {
domainChanged = true
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing reverse proxy: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
}
if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
if conflictService != nil && conflictService.ID != service.ID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
}
service.ProxyCluster = newCluster
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("failed to update reverse proxy: %w", err)
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
serviceEnabledChanged = existingService.Enabled != service.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
@@ -155,14 +280,54 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta())
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()))
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return reverseProxy, nil
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case domainChanged && oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error {
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -171,16 +336,16 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
return status.NewPermissionDeniedError()
}
var reverseProxy *reverseproxy.ReverseProxy
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
reverseProxy, err = transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteReverseProxy(ctx, accountID, reverseProxyID); err != nil {
return fmt.Errorf("failed to delete reverse proxy: %w", err)
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
@@ -189,46 +354,147 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
return err
}
m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta())
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()))
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
return fmt.Errorf("failed to get service: %w", err)
}
proxy.Meta.CertificateIssuedAt = time.Now()
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy certificate timestamp: %w", err)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the reverse proxy (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
return fmt.Errorf("failed to get service: %w", err)
}
proxy.Meta.Status = string(status)
service.Meta.Status = string(status)
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy status: %w", err)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}

View File

@@ -2,15 +2,18 @@ package reverseproxy
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/netbirdio/netbird/util/crypt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -33,18 +36,23 @@ const (
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
TargetTypePeer = "peer"
TargetTypeResource = "resource"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"`
Port int `json:"port"`
Protocol string `json:"protocol"`
TargetId string `json:"target_id"`
TargetType string `json:"target_type"`
Enabled bool `json:"enabled"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
@@ -68,6 +76,35 @@ type AuthConfig struct {
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
@@ -75,119 +112,147 @@ type OIDCValidationConfig struct {
MaxTokenAgeSeconds int64
}
type ReverseProxyMeta struct {
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type ReverseProxy struct {
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
return &ReverseProxy{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Targets: targets,
Enabled: enabled,
Meta: ReverseProxyMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
},
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
authConfig := api.ReverseProxyAuthConfig{}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
if r.Auth.PasswordAuth != nil {
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: r.Auth.PasswordAuth.Enabled,
Password: r.Auth.PasswordAuth.Password,
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if r.Auth.PinAuth != nil {
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: r.Auth.PinAuth.Enabled,
Pin: r.Auth.PinAuth.Pin,
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if r.Auth.BearerAuth != nil {
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: r.Auth.BearerAuth.Enabled,
DistributionGroups: &r.Auth.BearerAuth.DistributionGroups,
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
for _, target := range r.Targets {
apiTargets = append(apiTargets, api.ReverseProxyTarget{
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: target.Host,
Host: &target.Host,
Port: target.Port,
Protocol: api.ReverseProxyTargetProtocol(target.Protocol),
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ReverseProxyTargetTargetType(target.TargetType),
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ReverseProxyMeta{
CreatedAt: r.Meta.CreatedAt,
Status: api.ReverseProxyMetaStatus(r.Meta.Status),
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !r.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
return &api.ReverseProxy{
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
Targets: apiTargets,
Enabled: r.Enabled,
Auth: authConfig,
Meta: meta,
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(r.Targets))
for _, target := range r.Targets {
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: path,
}
if target.Port > 0 {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
@@ -195,30 +260,32 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
}
auth := &proto.Authentication{
SessionKey: r.SessionPublicKey,
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if r.Auth.PinAuth != nil && r.Auth.PinAuth.Enabled {
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: r.ID,
Domain: r.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: r.AccountID,
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
@@ -236,36 +303,54 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
}
}
func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID string) {
r.Name = req.Name
r.Domain = req.Domain
r.AccountID = accountID
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
targets := make([]Target, 0, len(req.Targets))
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
targets = append(targets, Target{
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Host: apiTarget.Host,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
})
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
}
r.Targets = targets
s.Targets = targets
r.Enabled = req.Enabled
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
r.Auth.PasswordAuth = &PasswordAuthConfig{
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
r.Auth.PinAuth = &PINAuthConfig{
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
@@ -278,59 +363,81 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
r.Auth.BearerAuth = bearerAuth
s.Auth.BearerAuth = bearerAuth
}
}
func (r *ReverseProxy) Validate() error {
if r.Name == "" {
return errors.New("reverse proxy name is required")
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(r.Name) > 255 {
return errors.New("reverse proxy name exceeds maximum length of 255 characters")
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if r.Domain == "" {
return errors.New("reverse proxy domain is required")
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(r.Targets) == 0 {
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (r *ReverseProxy) EventMeta() map[string]any {
return map[string]any{"name": r.Name, "domain": r.Domain}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]Target, len(r.Targets))
copy(targets, r.Targets)
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &ReverseProxy{
ID: r.ID,
AccountID: r.AccountID,
Name: r.Name,
Domain: r.Domain,
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: r.Enabled,
Auth: r.Auth,
Meta: r.Meta,
SessionPrivateKey: r.SessionPrivateKey,
SessionPublicKey: r.SessionPublicKey,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
if s.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
@@ -339,14 +446,14 @@ func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
return nil
}
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
if s.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}

View File

@@ -0,0 +1,405 @@
package reverseproxy
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}

View File

@@ -8,7 +8,6 @@ import (
"net/http"
"net/netip"
"slices"
"strings"
"time"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
@@ -95,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -123,11 +122,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
s.proxyAuthClose = proxyAuthClose
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
}
if s.Config.HttpConfig.LetsEncryptDomain != "" {
@@ -162,7 +163,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.Store(), s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
})
@@ -172,18 +173,12 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
return Create(s, func() nbgrpc.ProxyOIDCConfig {
// TODO: this is weird, double check
// Build callback URL - this should be the management server's callback endpoint
// For embedded IdP, derive from issuer. For external, use a configured value or derive from issuer.
// The callback URL should be registered in the IdP's allowed redirect URIs for the dashboard client.
callbackURL := strings.TrimSuffix(s.Config.HttpConfig.AuthIssuer, "/oauth2")
callbackURL = callbackURL + "/api/oauth/callback"
return nbgrpc.ProxyOIDCConfig{
Issuer: s.Config.HttpConfig.AuthIssuer,
ClientID: "netbird-dashboard", // Reuse dashboard client
Issuer: s.Config.HttpConfig.AuthIssuer,
// todo: double check auth clientID value
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
Scopes: []string{"openid", "profile", "email"},
CallbackURL: callbackURL,
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
Audience: s.Config.HttpConfig.AuthAudience,
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,

View File

@@ -100,6 +100,8 @@ type HttpServerConfig struct {
CertFile string
// CertKey is the location of the certificate private key
CertKey string
// AuthClientID is the client id used for proxy SSO auth
AuthClientID string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
@@ -117,6 +119,8 @@ type HttpServerConfig struct {
IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
// AuthCallbackDomain contains the callback domain
AuthCallbackURL string
}
// Host represents a Netbird host (e.g. STUN, TURN, Signal)

View File

@@ -9,7 +9,7 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -72,7 +72,14 @@ func (s *BaseServer) UsersManager() users.Manager {
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
idpConfig := settings.IdpConfig{}
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpConfig.EmbeddedIdpEnabled = true
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
}
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
})
}
@@ -94,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
})
return accountManager
})
}
@@ -150,7 +162,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
})
}
@@ -180,12 +192,13 @@ func (s *BaseServer) RecordsManager() records.Manager {
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ReverseProxyDomainManager() domain.Manager {
return Create(s, func() domain.Manager {
return domain.NewManager(s.Store(), s.ReverseProxyGRPCServer())
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
return &m
})
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -21,6 +20,7 @@ import (
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util/wsproxy"
@@ -58,6 +58,8 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
proxyAuthClose func()
listener net.Listener
certManager *autocert.Manager
update *version.Update
@@ -138,6 +140,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
go metricsWorker.Run(srvCtx)
}
// Run afterInit hooks before starting any servers
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -178,12 +188,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -215,6 +219,11 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
}
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {
@@ -255,7 +264,23 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
// SetHandlerFunc allows overriding the default HTTP handler function.
// This is useful for multiplexing additional services on the same port.
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
s.container["customHandler"] = handler
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
log.Tracef("using custom handler")
return handler
}
}
// Use default handler
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -13,7 +13,7 @@ import (
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a reverse proxy is created and must be used exactly once by the proxy
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
@@ -24,10 +24,10 @@ type OneTimeTokenStore struct {
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ReverseProxyID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
ServiceID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
@@ -48,10 +48,10 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// GenerateToken creates a new cryptographically secure one-time token
// with the specified TTL. The token is associated with a specific
// accountID and reverseProxyID for validation purposes.
// accountID and serviceID for validation purposes.
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl time.Duration) (string, error) {
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
@@ -65,20 +65,20 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ReverseProxyID: reverseProxyID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
reverseProxyID, accountID, ttl)
serviceID, accountID, ttl)
return token, nil
}
// ValidateAndConsume verifies the token against the provided accountID and
// reverseProxyID, checks expiration, and then deletes it to enforce single-use.
// serviceID, checks expiration, and then deletes it to enforce single-use.
//
// This method uses constant-time comparison to prevent timing attacks.
//
@@ -87,14 +87,14 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl
// - Token has expired
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID string) error {
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
reverseProxyID, accountID)
serviceID, accountID)
return fmt.Errorf("invalid token")
}
@@ -102,7 +102,7 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
reverseProxyID, accountID)
serviceID, accountID)
return fmt.Errorf("token expired")
}
@@ -113,18 +113,18 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID
return fmt.Errorf("account ID mismatch")
}
// Validate reverse proxy ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ReverseProxyID), []byte(reverseProxyID)) != 1 {
log.Warnf("Token validation failed: reverse proxy ID mismatch (expected: %s, got: %s)",
metadata.ReverseProxyID, reverseProxyID)
return fmt.Errorf("reverse proxy ID mismatch")
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
log.Infof("Token validated and consumed for proxy %s in account %s",
reverseProxyID, accountID)
serviceID, accountID)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,234 @@
package grpc
import (
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
lastUsedUpdateInterval = time.Minute
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
lastUsedCleanupInterval = 2 * time.Minute
)
type proxyTokenContextKey struct{}
// ProxyTokenContextKey is the typed key used to store validated token info in context.
var ProxyTokenContextKey = proxyTokenContextKey{}
// proxyTokenID identifies a proxy access token by its database ID.
type proxyTokenID = string
// proxyTokenStore defines the store interface needed for token validation
type proxyTokenStore interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
}
// proxyAuthInterceptor holds state for proxy authentication interceptors.
type proxyAuthInterceptor struct {
store proxyTokenStore
failureLimiter *authFailureLimiter
// lastUsedMu protects lastUsedTimes
lastUsedMu sync.Mutex
lastUsedTimes map[proxyTokenID]time.Time
cancel context.CancelFunc
}
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
ctx, cancel := context.WithCancel(context.Background())
i := &proxyAuthInterceptor{
store: tokenStore,
failureLimiter: newAuthFailureLimiter(),
lastUsedTimes: make(map[proxyTokenID]time.Time),
cancel: cancel,
}
go i.lastUsedCleanupLoop(ctx)
return i
}
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
// The returned close function must be called on shutdown to stop background goroutines.
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
interceptor := newProxyAuthInterceptor(tokenStore)
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(ctx, req)
}
token, err := interceptor.validateProxyToken(ctx)
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
return nil, err
}
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
return handler(ctx, req)
}
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(srv, ss)
}
token, err := interceptor.validateProxyToken(ss.Context())
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
return err
}
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
wrapped := &wrappedServerStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, wrapped)
}
return unary, stream, interceptor.close
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
}
token, err := i.doValidateProxyToken(ctx)
if err != nil {
if clientIP != "" {
i.failureLimiter.recordFailure(clientIP)
}
return nil, err
}
i.maybeUpdateLastUsed(ctx, token.ID)
return token, nil
}
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
}
authValues := md.Get("authorization")
if len(authValues) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
}
authValue := authValues[0]
if !strings.HasPrefix(authValue, "Bearer ") {
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
}
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
if err := plainToken.Validate(); err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
}
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}
return token, nil
}
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
now := time.Now()
i.lastUsedMu.Lock()
lastUpdate, exists := i.lastUsedTimes[tokenID]
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
i.lastUsedMu.Unlock()
return
}
i.lastUsedTimes[tokenID] = now
i.lastUsedMu.Unlock()
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
}
}
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
ticker := time.NewTicker(lastUsedCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
i.cleanupStaleLastUsed()
case <-ctx.Done():
return
}
}
}
// cleanupStaleLastUsed removes entries older than 2x the update interval.
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
i.lastUsedMu.Lock()
defer i.lastUsedMu.Unlock()
now := time.Now()
staleThreshold := 2 * lastUsedUpdateInterval
for id, lastUpdate := range i.lastUsedTimes {
if now.Sub(lastUpdate) > staleThreshold {
delete(i.lastUsedTimes, id)
}
}
}
func (i *proxyAuthInterceptor) close() {
i.cancel()
i.failureLimiter.stop()
}
// GetProxyTokenFromContext retrieves the validated proxy token from the context
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
if !ok {
return nil
}
return token
}
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

View File

@@ -0,0 +1,134 @@
package grpc
import (
"context"
"net"
"sync"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"golang.org/x/time/rate"
"google.golang.org/grpc/peer"
)
const (
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
proxyAuthFailureBurst = 5
// proxyAuthLimiterCleanup is how often stale limiters are removed.
proxyAuthLimiterCleanup = 5 * time.Minute
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
proxyAuthLimiterTTL = 15 * time.Minute
)
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
// One token every 12 seconds = 5 per minute.
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
// clientIP identifies a client by its IP address for rate limiting purposes.
type clientIP = string
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
type authFailureLimiter struct {
mu sync.Mutex
limiters map[clientIP]*limiterEntry
failureRate rate.Limit
cancel context.CancelFunc
}
func newAuthFailureLimiter() *authFailureLimiter {
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
}
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
ctx, cancel := context.WithCancel(context.Background())
l := &authFailureLimiter{
limiters: make(map[clientIP]*limiterEntry),
failureRate: failureRate,
cancel: cancel,
}
go l.cleanupLoop(ctx)
return l
}
// isLimited returns true if the given IP has exhausted its failure budget.
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
l.mu.Lock()
defer l.mu.Unlock()
entry, exists := l.limiters[ip]
if !exists {
return false
}
return entry.limiter.Tokens() < 1
}
// recordFailure consumes a token from the rate limiter for the given IP.
func (l *authFailureLimiter) recordFailure(ip clientIP) {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
entry, exists := l.limiters[ip]
if !exists {
entry = &limiterEntry{
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
}
l.limiters[ip] = entry
}
entry.lastAccess = now
entry.limiter.Allow()
}
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(proxyAuthLimiterCleanup)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.cleanup()
case <-ctx.Done():
return
}
}
}
func (l *authFailureLimiter) cleanup() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, entry := range l.limiters {
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
delete(l.limiters, ip)
}
}
}
func (l *authFailureLimiter) stop() {
l.cancel()
}
// peerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return p.Addr.String()
}
return host
}
return ""
}

View File

@@ -0,0 +1,98 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -0,0 +1,381 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -0,0 +1,232 @@
package grpc
import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -77,8 +77,9 @@ type Server struct {
oAuthConfigProvider idp.OAuthConfigProvider
syncSem atomic.Int32
syncLim int32
syncSem atomic.Int32
syncLimEnabled bool
syncLim int32
}
// NewServer creates a new Management server
@@ -108,6 +109,7 @@ func NewServer(
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
syncLimEnabled := true
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
@@ -115,6 +117,9 @@ func NewServer(
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
if syncLim < 0 {
syncLimEnabled = false
}
}
}
@@ -134,7 +139,8 @@ func NewServer(
loginFilter: newLoginFilter(),
syncLim: syncLim,
syncLim: syncLim,
syncLimEnabled: syncLimEnabled,
}, nil
}
@@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim {
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
@@ -294,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -305,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
return err
}
@@ -313,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
return err
}
@@ -330,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
@@ -398,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
// It implements a backpressure mechanism that sends the first update immediately,
// then debounces subsequent rapid updates, ensuring only the latest update is sent
// after a quiet period.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
// Create a debouncer for this peer connection
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
defer debouncer.Stop()
for {
select {
// condition when there are some updates
// todo set the updates channel size to 1
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
@@ -410,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// Timer expired - quiet period reached, send pending updates if any
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
continue
}
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
for _, pendingUpdate := range pendingUpdates {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return srv.Context().Err()
}
}
@@ -431,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.Send(&proto.EncryptedMessage{
@@ -448,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
@@ -480,11 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}

View File

@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {

View File

@@ -0,0 +1,103 @@
package grpc
import (
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
)
// UpdateDebouncer implements a backpressure mechanism that:
// - Sends the first update immediately
// - Coalesces rapid subsequent network map updates (only latest matters)
// - Queues control/config updates (all must be delivered)
// - Preserves the order of messages (important for control configs between network maps)
// - Ensures pending updates are sent after a quiet period
type UpdateDebouncer struct {
debounceInterval time.Duration
timer *time.Timer
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
timerC <-chan time.Time
}
// NewUpdateDebouncer creates a new debouncer with the specified interval
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
return &UpdateDebouncer{
debounceInterval: interval,
}
}
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
if d.timer == nil {
// No active debounce timer, signal to send immediately
// and start the debounce period
d.startTimer()
return true
}
// Already in debounce period, accumulate this update preserving order
// Check if we should coalesce with the last pending update
if len(d.pendingUpdates) > 0 &&
update.MessageType == network_map.MessageTypeNetworkMap &&
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
// Replace the last network map with this one (coalesce consecutive network maps)
d.pendingUpdates[len(d.pendingUpdates)-1] = update
} else {
// Append to the queue (preserves order for control configs and non-consecutive network maps)
d.pendingUpdates = append(d.pendingUpdates, update)
}
d.resetTimer()
return false
}
// TimerChannel returns the timer channel for select statements
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
if d.timer == nil {
return nil
}
return d.timerC
}
// GetPendingUpdates returns and clears all pending updates after timer expiration.
// Updates are returned in the order they were received, with consecutive network maps
// already coalesced to only the latest one.
// If there were pending updates, it restarts the timer to continue debouncing.
// If there were no pending updates, it clears the timer (true quiet period).
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
updates := d.pendingUpdates
d.pendingUpdates = nil
if len(updates) > 0 {
// There were pending updates, so updates are still coming rapidly
// Restart the timer to continue debouncing mode
if d.timer != nil {
d.timer.Reset(d.debounceInterval)
}
} else {
// No pending updates means true quiet period - return to immediate mode
d.timer = nil
d.timerC = nil
}
return updates
}
// Stop stops the debouncer and cleans up resources
func (d *UpdateDebouncer) Stop() {
if d.timer != nil {
d.timer.Stop()
d.timer = nil
d.timerC = nil
}
d.pendingUpdates = nil
}
func (d *UpdateDebouncer) startTimer() {
d.timer = time.NewTimer(d.debounceInterval)
d.timerC = d.timer.C
}
func (d *UpdateDebouncer) resetTimer() {
d.timer.Stop()
d.timer.Reset(d.debounceInterval)
}

View File

@@ -0,0 +1,587 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -0,0 +1,304 @@
//go:build integration
package grpc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type validateSessionTestSetup struct {
proxyService *ProxyServiceServer
store store.Store
cleanup func()
}
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
createTestProxies(t, ctx, testStore)
return &validateSessionTestSetup{
proxyService: proxyService,
store: testStore,
cleanup: storeCleanup,
}
}
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
}
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "restricted-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User in different account should be denied")
assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Non-existent user should be denied")
assert.Equal(t, "user_not_found", resp.DeniedReason)
}
func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "unknown-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: "invalid-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Invalid token should be denied")
assert.Equal(t, "invalid_token", resp.DeniedReason)
}
func TestValidateSession_MissingDomain(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
SessionToken: "some-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
func TestValidateSession_MissingToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}

View File

@@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -26,7 +27,6 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -82,8 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -113,6 +115,10 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -321,6 +327,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
updateAccountPeers = true
}
@@ -795,6 +804,19 @@ func IsEmbeddedIdp(i idp.Manager) bool {
return ok
}
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
// Returns true only when using embedded IDP with local auth disabled in config.
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
if isNil(i) {
return false
}
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
if !ok {
return false
}
return embeddedIdp.IsLocalAuthDisabled()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
@@ -1657,13 +1679,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1671,8 +1693,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -58,7 +59,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
@@ -114,8 +115,8 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@@ -139,4 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
}

View File

@@ -27,6 +27,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
@@ -1800,6 +1802,14 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
Services: []*reverseproxy.Service{
{
ID: "service1",
Name: "test-service",
AccountID: "account1",
Targets: []*reverseproxy.Target{},
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
@@ -1881,7 +1891,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1952,7 +1962,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1961,6 +1971,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
})
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1983,7 +2069,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3036,6 +3122,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
return manager, updateManager, nil
}
@@ -3176,7 +3264,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
assert.NoError(b, err)
}

View File

@@ -204,9 +204,9 @@ const (
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
ReverseProxyCreated Activity = 108
ReverseProxyUpdated Activity = 109
ReverseProxyDeleted Activity = 110
ServiceCreated Activity = 108
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
AccountDeleted Activity = 99999
)
@@ -342,9 +342,9 @@ var activityMap = map[Activity]Code{
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
ReverseProxyCreated: {"Reverse proxy created", "reverseproxy.create"},
ReverseProxyUpdated: {"Reverse proxy updated", "reverseproxy.update"},
ReverseProxyDeleted: {"Reverse proxy deleted", "reverseproxy.delete"},
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

Some files were not shown because too many files have changed in this diff Show More