Compare commits

..

53 Commits

Author SHA1 Message Date
bcmmbaga
72513d7522 Skip network map calculation when client serial matches current
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-01-27 23:03:43 +03:00
Zoltan Papp
a1f1bf1f19 Merge branch 'main' into feat/network-map-serial 2025-12-18 15:59:53 +01:00
Zoltan Papp
b5dec3df39 Track network serial in engine 2025-12-18 15:27:49 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
Bethuel Mmbaga
031ab11178 [client] Remove select account prompt (#4912)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-12-04 14:57:29 +01:00
Zoltan Papp
d2e48d4f5e [relay] Use instanceURL instead of Exposed address. (#4905)
Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields.
2025-12-03 18:42:53 +01:00
Bethuel Mmbaga
27dd97c9c4 [management] Add support to disable geolocation service (#4901) 2025-12-03 14:45:59 +03:00
Maycon Santos
e87b4ace11 [client] Add sleep state tracking to handle wakeup/sleep events reliably (#4894)
Adds a new NotifyOSLifecycle RPC and server handler to centralize OS sleep/wake handling, introduces Server.sleepTriggeredDown for coordination, updates client UI to call the new RPC, and adjusts the internal sleep event enum zero-value semantics.
2025-12-03 11:53:39 +01:00
Pascal Fischer
a232cf614c [management] record pat usage metrics (#4888) 2025-12-02 18:31:59 +01:00
Maycon Santos
a293f760af [client] Add conditional peer removal logic during shutdown (#4897) 2025-12-02 16:30:15 +01:00
Pascal Fischer
10e9cf8c62 [management] update management integrations (#4895) 2025-12-02 14:13:01 +01:00
Pascal Fischer
7193bd2da7 [management] Refactor network map controller (#4789) 2025-12-02 12:34:28 +01:00
Bethuel Mmbaga
52948ccd61 [management] Add user created activity event (#4893) 2025-12-02 14:17:59 +03:00
Fahri Shihab
4b77359042 [management] Groups API with name query parameter (#4831) 2025-12-01 16:57:42 +01:00
Zoltan Papp
387d43bcc1 [client, management] Add OAuth select_account prompt support to PKCE flow (#4880)
* Add OAuth select_account prompt support to PKCE flow

Extends LoginFlag enum with select_account options to enable
multi-account selection during authentication. This allows users
to choose which account to use when multiple accounts have active
sessions with the identity provider.

The new flags are backward compatible - existing LoginFlag values
(0=prompt login, 1=max_age=0) retain their original behavior.
2025-12-01 14:25:52 +01:00
Zoltan Papp
e47d815dd2 Fix IsAnotherProcessRunning (#4858)
Compare the exact process name rather than searching for a substring of the full path
2025-12-01 14:16:03 +01:00
shuuri-labs
cb83b7c0d3 [relay] use exposed address for healthcheck TLS validation (#4872)
* fix(relay): use exposed address for healthcheck TLS validation

Healthcheck was using listen address (0.0.0.0) instead of exposed address
(domain name) for certificate validation, causing validation to always fail.

Now correctly uses the exposed address where the TLS certificate is valid,
matching real client connection behavior.

* - store exposedAddress directly in Relay struct instead of parsing on every call
- remove unused parseHostPort() function
- remove unused ListenAddress() method from ServiceChecker interface
- improve error logging with address context

* [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow

Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses.

* [relay/healthcheck] Fix certificate validation status during health check

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-11-28 21:53:53 +01:00
Zoltan Papp
ddcd182859 [client] Sleep detection on macOS (#4859)
A macOS-specific sleep detection mechanism using IOKit and CoreFoundation via cgo is introduced, with a fallback implementation for unsupported platforms. A public Service wrapper provides an event-driven API translating system sleep/wake events into gRPC calls. The UI client integrates sleep detection to manage connectivity state based on system sleep status.
2025-11-28 17:26:22 +01:00
Maycon Santos
aca0398105 [client] Add excluded port range handling for PKCE flow (#4853) 2025-11-26 16:07:45 +01:00
Viktor Liu
02200d790b [client] Open browser for ssh automatically (#4838) 2025-11-26 16:06:47 +01:00
Bethuel Mmbaga
f31bba87b4 [management] Preserve validator settings on account settings update (#4862) 2025-11-26 17:07:44 +03:00
shuuri-labs
7285fef0f0 feat: Add support for displaying device code (UserCode) on Android TV SSO flow (#4800)
- Modified URLOpener interface to pass userCode alongside URL in login.go
- added ability to force device auth flow
2025-11-25 15:51:16 +01:00
Maycon Santos
20973063d8 [client] Support disable search domain for custom zones (#4826)
Two new boolean flags, SearchDomainDisabled and SkipPTRProcess, are added to CustomZone and its protobuf; they are propagated through the engine to DNS host logic. Host matching now uses SearchDomainDisabled directly, and PTR collection skips zones with SkipPTRProcess; reverse zones are initialized with SearchDomainDisabled: true.
2025-11-24 17:50:08 +01:00
Aziz Hasanain
ba2e9b6d88 [management] Fix SSH JWT issuer derivation for IDPs with path components (#4844) 2025-11-24 12:12:51 +01:00
Viktor Liu
131d7a3694 [client] Make mss clamping optional for nftables (#4843) 2025-11-22 18:57:07 +01:00
Maycon Santos
290fe2d8b9 [client/management/signal/relay] Update go.mod to use Go 1.24.10 and upgrade x/crypto dependencies (#4828)
Upgrade Go toolchain and golang.org/x/* deps to 1.24.10, standardize GitHub Actions to derive Go version from go.mod and adjust checkout ordering, raise WASM size limit to 55 MB, update FreeBSD tarball and gomobile refs, fix a few format-string/logging calls, treat usernames ending with $ as system accounts, and add Windows tests.
2025-11-22 10:10:18 +01:00
Vlad
7fb1a2fe31 [management] removed TestBufferUpdateAccountPeers because it was incorrect (#4839) 2025-11-22 01:23:33 +01:00
Diego Romar
32146e576d [android] allow selection/deselection of network resources on android peers (#4607) 2025-11-21 13:36:33 +01:00
Viktor Liu
1311364397 [client] Increase ssh detection timeout (#4827) 2025-11-20 17:09:22 +01:00
Maycon Santos
68f56b797d [management] Add native ssh port rule on 22 (#4810)
Implements feature-aware firewall rule expansion: derives peer-supported features (native SSH, portRanges) from peer version, prefers explicit Ports over PortRanges when expanding, conditionally appends a native SSH (22022) rule when policy and peer support allow, and adds helpers plus tests for SSH expansion behavior.
2025-11-19 13:16:47 +01:00
Pascal Fischer
3351b38434 [management] pass config to controller (#4807) 2025-11-19 11:52:18 +01:00
Pascal Fischer
05cbead39b [management] Fix direct peer networks route (#4802) 2025-11-18 17:15:57 +01:00
Viktor Liu
60f4d5f9b0 [client] Revert migrate deprecated grpc client code #4805 2025-11-18 12:41:17 +01:00
Vlad
4eeb2d8deb [management] added exception on not appending route firewall rules if we have all wildcard (#4801) 2025-11-17 18:20:30 +01:00
Viktor Liu
d71a82769c [client,management] Rewrite the SSH feature (#4015) 2025-11-17 17:10:41 +01:00
Misha Bragin
0d79301141 Update client login success page (#4797) 2025-11-17 15:28:20 +01:00
Hakan Sariman
20f5f00635 [client] Add unit tests for engine synchronization and Info flag copying
- Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil.
- Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields.
2025-10-17 10:03:07 +03:00
Hakan Sariman
fc141cf3a3 [client] Refactor lastNetworkMapSerial handling in GrpcClient
- Removed atomic operations for lastNetworkMapSerial and replaced them with mutex-based methods for thread-safe access.
2025-09-29 18:49:23 +07:00
Hakan Sariman
d0c65fa08e [client] Add skipNetworkMapUpdate field to SyncResponse for conditional updates 2025-09-29 18:28:14 +07:00
Hakan Sariman
f241bfa339 Refactor flag setting in Info struct to use CopyFlagsFrom method 2025-09-29 15:38:35 +07:00
Hakan Sariman
4b2cd97d5f [client] Enhance SyncRequest with NetworkMap serial tracking
- Added `networkMapSerial` field to `SyncRequest` for tracking the last known network map serial number.
- Updated `GrpcClient` to store and utilize the last network map serial during sync operations, optimizing synchronization processes.
- Improved handling of system info updates to ensure accurate metadata is sent with sync requests.
2025-09-25 19:28:35 +07:00
288 changed files with 24154 additions and 4388 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -19,35 +19,37 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ ! -z "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo "✓ No problematic dependencies found"
echo ""
echo "✅ All internal license dependencies are clean"
fi
done
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses

View File

@@ -15,13 +15,14 @@ jobs:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -30,7 +30,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Get Go environment
@@ -106,15 +106,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -151,15 +151,15 @@ jobs:
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.23-alpine \
golang:1.24-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -220,15 +220,15 @@ jobs:
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -270,15 +270,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -321,15 +321,15 @@ jobs:
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -408,15 +408,16 @@ jobs:
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -497,15 +498,15 @@ jobs:
-p 9090:9090 \
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -561,15 +562,15 @@ jobs:
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -24,7 +24,7 @@ jobs:
uses: actions/setup-go@v5
id: go
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Get Go environment

View File

@@ -46,7 +46,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
@@ -39,7 +39,7 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
- name: gomobile init
run: gomobile init
- name: build android netbird lib
@@ -56,9 +56,9 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib

View File

@@ -20,7 +20,7 @@ concurrency:
jobs:
release:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest-m
env:
flags: ""
steps:
@@ -40,7 +40,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -136,7 +136,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -200,7 +200,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -67,10 +67,13 @@ jobs:
- name: Install curl
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@v4
@@ -80,9 +83,6 @@ jobs:
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
@@ -45,7 +45,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
@@ -60,8 +60,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
exit 1
fi

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy
```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -4,10 +4,13 @@ package android
import (
"context"
"fmt"
"os"
"slices"
"sync"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
@@ -16,10 +19,13 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -62,17 +68,18 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -80,11 +87,12 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
@@ -107,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.login(urlOpener)
err = auth.login(urlOpener, isAndroidTV)
if err != nil {
return err
}
@@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// Stop the internal client and free the resources
@@ -156,6 +164,19 @@ func (c *Client) Stop() {
c.ctxCancel()
}
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
return e.RenewTun(fd)
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray {
p.IP,
p.FQDN,
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi
}
@@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray {
return nil
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
log.Error("could not get route selector")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
}
networkArray.Add(network)
}
@@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func (c *Client) toggleRoute(command routeCommand) error {
return command.toggleRoute()
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}
engine := client.Engine()
if engine == nil {
return nil, fmt.Errorf("engine is not running")
}
manager := engine.GetRouteManager()
if manager == nil {
return nil, fmt.Errorf("could not get route manager")
}
return manager, nil
}
func (c *Client) SelectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
}
func (c *Client) DeselectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
}
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
// with its resolved IP addresses from the provided resolvedDomains map.
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
domains := NetworkDomains{}
for _, d := range route.Domains {
networkDomain := NetworkDomain{
Address: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
for _, prefix := range info.Prefixes {
networkDomain.addResolvedIP(prefix.Addr().String())
}
}
domains.Add(&networkDomain)
}
return domains
}
func exportEnvList(list *EnvList) {
if list == nil {
return

View File

@@ -32,7 +32,7 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
Open(url string, userCode string)
OnLoginSuccess()
}
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
}
// Login try register the client on the server
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
go func() {
err := a.login(urlOpener)
err := a.login(urlOpener, isAndroidTV)
if err != nil {
resultListener.OnError(err)
} else {
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
}()
}
func (a *Auth) login(urlOpener URLOpener) error {
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
// check if we need to generate JWT token
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
if err != nil {
return nil, err
}
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete)
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)

View File

@@ -0,0 +1,56 @@
//go:build android
package android
import "fmt"
type ResolvedIPs struct {
resolvedIPs []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
}
func (r *ResolvedIPs) Get(i int) (string, error) {
if i < 0 || i >= len(r.resolvedIPs) {
return "", fmt.Errorf("%d is out of range", i)
}
return r.resolvedIPs[i], nil
}
func (r *ResolvedIPs) Size() int {
return len(r.resolvedIPs)
}
type NetworkDomain struct {
Address string
resolvedIPs ResolvedIPs
}
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
d.resolvedIPs.Add(resolvedIP)
}
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type NetworkDomains struct {
domains []*NetworkDomain
}
func (n *NetworkDomains) Add(domain *NetworkDomain) {
n.domains = append(n.domains, domain)
}
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
if i < 0 || i >= len(n.domains) {
return nil, fmt.Errorf("%d is out of range", i)
}
return n.domains[i], nil
}
func (n *NetworkDomains) Size() int {
return len(n.domains)
}

View File

@@ -3,10 +3,16 @@
package android
type Network struct {
Name string
Network string
Peer string
Status string
Name string
Network string
Peer string
Status string
IsSelected bool
Domains NetworkDomains
}
func (n Network) GetNetworkDomains() *NetworkDomains {
return &n.Domains
}
type NetworkArray struct {

View File

@@ -1,3 +1,5 @@
//go:build android
package android
// PeerInfo describe information about the peers. It designed for the UI usage
@@ -5,6 +7,11 @@ type PeerInfo struct {
IP string
FQDN string
ConnStatus string // Todo replace to enum
Routes PeerRoutes
}
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
return &p.Routes
}
// PeerInfoArray is a wrapper of []PeerInfo

View File

@@ -0,0 +1,20 @@
//go:build android
package android
import "fmt"
type PeerRoutes struct {
routes []string
}
func (p *PeerRoutes) Get(i int) (string, error) {
if i < 0 || i >= len(p.routes) {
return "", fmt.Errorf("%d is out of range", i)
}
return p.routes[i], nil
}
func (p *PeerRoutes) Size() int {
return len(p.routes)
}

View File

@@ -0,0 +1,10 @@
//go:build android
package android
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
// at their default locations due to android OS restrictions.
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
}

View File

@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetEnableSSHRoot reads SSH root login setting from config file
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
if p.configInput.EnableSSHRoot != nil {
return *p.configInput.EnableSSHRoot, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRoot == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRoot, err
}
// SetEnableSSHRoot stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
p.configInput.EnableSSHRoot = &enabled
}
// GetEnableSSHSFTP reads SSH SFTP setting from config file
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
if p.configInput.EnableSSHSFTP != nil {
return *p.configInput.EnableSSHSFTP, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHSFTP == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHSFTP, err
}
// SetEnableSSHSFTP stores the given value and waits for commit
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
p.configInput.EnableSSHSFTP = &enabled
}
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
if p.configInput.EnableSSHLocalPortForwarding != nil {
return *p.configInput.EnableSSHLocalPortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHLocalPortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHLocalPortForwarding, err
}
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
p.configInput.EnableSSHLocalPortForwarding = &enabled
}
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
if p.configInput.EnableSSHRemotePortForwarding != nil {
return *p.configInput.EnableSSHRemotePortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRemotePortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRemotePortForwarding, err
}
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
p.configInput.EnableSSHRemotePortForwarding = &enabled
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {

View File

@@ -0,0 +1,67 @@
//go:build android
package android
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/route"
)
func executeRouteToggle(id string, manager routemanager.Manager,
operationName string,
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
netID := route.NetID(id)
routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}
manager.TriggerSelection(manager.GetClientRoutes())
return nil
}
type routeCommand interface {
toggleRoute() error
}
type selectRouteCommand struct {
route string
manager routemanager.Manager
}
func (s selectRouteCommand) toggleRoute() error {
routeSelector := s.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
return routeSelector.SelectRoutes(routes, true, allRoutes)
}
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
}
type deselectRouteCommand struct {
route string
manager routemanager.Manager
}
func (d deselectRouteCommand) toggleRoute() error {
routeSelector := d.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
}

View File

@@ -4,14 +4,12 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
if err != nil {
return nil, err
}
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil {
if err := util.OpenBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -35,7 +35,6 @@ const (
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
@@ -64,7 +63,6 @@ var (
customDNSAddress string
rosenpassEnabled bool
rosenpassPermissive bool
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
networkMonitor bool
@@ -176,7 +174,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")

View File

@@ -3,125 +3,849 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
var (
port int
userName = "root"
host string
const (
sshUsernameDesc = "SSH username"
hostArgumentRequired = "host argument required"
serverSSHAllowedFlag = "allow-server-ssh"
enableSSHRootFlag = "enable-ssh-root"
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
var (
port int
username string
host string
command string
localForwards []string
remoteForwards []string
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
requestPTY bool
sshNoBrowser bool
)
split := strings.Split(args[0], "@")
if len(split) == 2 {
userName = split[0]
host = split[1]
} else {
host = args[0]
}
var (
serverSSHAllowed bool
enableSSHRoot bool
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
)
return nil
},
Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
cmd.SetOut(cmd.OutOrStdout())
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
}
ctx := internal.CtxInitState(cmd.Context())
sm := profilemanager.NewServiceManager(configPath)
activeProf, err := sm.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()
}()
select {
case <-sig:
cancel()
case <-sshctx.Done():
}
return nil
},
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
var sshCmd = &cobra.Command{
Use: "ssh [flags] [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
Port Forwarding:
-L [bind_address:]port:host:hostport Local port forwarding
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
-R [bind_address:]port:host:hostport Remote port forwarding
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
SSH Options:
-p, --port int Remote SSH port (default 22)
-u, --user string SSH username
--login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file
Examples:
netbird ssh peer-hostname
netbird ssh root@peer-hostname
netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: sshFn,
Aliases: []string{"ssh"},
}
func sshFn(cmd *cobra.Command, args []string) error {
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return cmd.Help()
}
}
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err
}
cancel()
}()
err = c.OpenTerminal()
if err != nil {
select {
case <-sig:
cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done():
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
func getEnvOrDefault(flagName, defaultValue string) string {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
return envValue
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
return envValue
}
return defaultValue
}
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
username = ""
host = ""
command = ""
localForwards = nil
remoteForwards = nil
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
sshNoBrowser = false
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
var localForwardFlags []string
var remoteForwardFlags []string
var filteredArgs []string
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case strings.HasPrefix(arg, "-L"):
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
case strings.HasPrefix(arg, "-R"):
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
default:
filteredArgs = append(filteredArgs, arg)
}
}
return filteredArgs, localForwardFlags, remoteForwardFlags
}
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
if arg == "-L" || arg == "-R" {
if i+1 < len(args) {
flags = append(flags, args[i+1])
i++
}
} else if len(arg) > 2 {
flags = append(flags, arg[2:])
}
return flags, i
}
// extractGlobalFlags parses global flags that were passed before 'ssh' command
func extractGlobalFlags(args []string) {
sshPos := findSSHCommandPosition(args)
if sshPos == -1 {
return
}
globalArgs := args[:sshPos]
parseGlobalArgs(globalArgs)
}
// findSSHCommandPosition locates the 'ssh' command in the argument list
func findSSHCommandPosition(args []string) int {
for i, arg := range args {
if arg == "ssh" {
return i
}
}
return -1
}
const (
configFlag = "config"
logLevelFlag = "log-level"
logFileFlag = "log-file"
)
// parseGlobalArgs processes the global arguments and sets the corresponding variables
func parseGlobalArgs(globalArgs []string) {
flagHandlers := map[string]func(string){
configFlag: func(value string) { configPath = value },
logLevelFlag: func(value string) { logLevel = value },
logFileFlag: func(value string) {
if !slices.Contains(logFiles, value) {
logFiles = append(logFiles, value)
}
},
}
shortFlags := map[string]string{
"c": configFlag,
"l": logLevelFlag,
}
for i := 0; i < len(globalArgs); i++ {
arg := globalArgs[i]
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
i = nextIndex
}
}
}
// parseFlag handles generic flag parsing for both long and short forms
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex
}
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex + 1
}
return false, currentIndex
}
type parsedFlag struct {
flagName string
value string
}
// parseEqualsFormat handles --flag=value and -f=value formats
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if !strings.Contains(arg, "=") {
return parsedFlag{}, false
}
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return parsedFlag{}, false
}
if strings.HasPrefix(parts[0], "--") {
flagName := strings.TrimPrefix(parts[0], "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: parts[1]}, true
}
}
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
shortFlag := strings.TrimPrefix(parts[0], "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: parts[1]}, true
}
}
}
return parsedFlag{}, false
}
// parseSpacedFormat handles --flag value and -f value formats
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if currentIndex+1 >= len(args) {
return parsedFlag{}, false
}
if strings.HasPrefix(arg, "--") {
flagName := strings.TrimPrefix(arg, "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
}
}
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
shortFlag := strings.TrimPrefix(arg, "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
}
}
}
return parsedFlag{}, false
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
RequestPTY bool
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
NoBrowser bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
resetSSHGlobals()
if len(os.Args) > 2 {
extractGlobalFlags(os.Args[1:])
}
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
}
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New(hostArgumentRequired)
}
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
sshNoBrowser = flags.NoBrowser
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
remoteForwards = remoteForwardFlags
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
NoBrowser: sshNoBrowser,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
return fmt.Errorf("dial %s: %w", target, err)
}
sshCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-sshCtx.Done()
if err := c.Close(); err != nil {
cmd.Printf("Error closing SSH connection: %v\n", err)
}
}()
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
return fmt.Errorf("start port forwarding: %w", err)
}
if command != "" {
return executeSSHCommand(sshCtx, c, command)
}
return openSSHTerminal(sshCtx, c)
}
// executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err)
}
return nil
}
// openSSHTerminal opens an interactive SSH terminal.
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if err := c.OpenTerminal(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err)
}
return nil
}
// startPortForwarding starts local and remote port forwarding based on command line flags
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
for _, forward := range localForwards {
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("local port forward %s: %w", forward, err)
}
}
for _, forward := range remoteForwards {
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("remote port forward %s: %w", forward, err)
}
}
return nil
}
// parseAndStartLocalForward parses and starts a local port forward (-L)
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Local port forward error: %v\n", err)
}
}()
return nil
}
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Remote port forward error: %v\n", err)
}
}()
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
// Support formats:
// port:host:hostport -> localhost:port -> host:hostport
// host:port:host:hostport -> host:port -> host:hostport
// [host]:port:host:hostport -> [host]:port -> host:hostport
// port:unix_socket_path -> localhost:port -> unix_socket_path
// host:port:unix_socket_path -> host:port -> unix_socket_path
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
return parseIPv6ForwardSpec(spec)
}
parts := strings.Split(spec, ":")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
}
switch len(parts) {
case 2:
return parseTwoPartForwardSpec(parts, spec)
case 3:
return parseThreePartForwardSpec(parts)
case 4:
return parseFourPartForwardSpec(parts)
default:
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
}
}
// parseTwoPartForwardSpec handles "port:unix_socket" format.
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
if isUnixSocket(parts[1]) {
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1]
return localAddr, remoteAddr, nil
}
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
}
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
func parseThreePartForwardSpec(parts []string) (string, string, error) {
if isUnixSocket(parts[2]) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2]
return localAddr, remoteAddr, nil
}
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
func parseFourPartForwardSpec(parts []string) (string, string, error) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2] + ":" + parts[3]
return localAddr, remoteAddr, nil
}
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
func parseIPv6ForwardSpec(spec string) (string, string, error) {
idx := strings.Index(spec, "]:")
if idx == -1 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
}
ipv6Host := spec[:idx+1]
remaining := spec[idx+2:]
parts := strings.Split(remaining, ":")
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
}
localAddr := ipv6Host + ":" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// isUnixSocket checks if a path is a Unix socket path.
func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string {
if host == "*" {
return "0.0.0.0"
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
// where command-line flags cannot be passed. Default is to open browser.
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(detectLogLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
log.Debugf("invalid port %q: %v", portStr, err)
os.Exit(detection.ServerTypeRegular.ExitCode())
}
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("SSH server detection failed: %v", err)
cancel()
os.Exit(detection.ServerTypeRegular.ExitCode())
}
cancel()
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -0,0 +1,74 @@
//go:build unix
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sshExecUID uint32
sshExecGID uint32
sshExecGroups []uint
sshExecWorkingDir string
sshExecShell string
sshExecCommand string
sshExecPTY bool
)
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
var sshExecCmd = &cobra.Command{
Use: "exec",
Short: "Internal SSH execution with privilege dropping (hidden)",
Hidden: true,
RunE: runSSHExec,
}
func init() {
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
os.Exit(1)
}
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
os.Exit(1)
}
sshCmd.AddCommand(sshExecCmd)
}
// runSSHExec handles the SSH exec subcommand execution.
func runSSHExec(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sshExecGroups {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sshExecUID,
GID: sshExecGID,
Groups: groups,
WorkingDir: sshExecWorkingDir,
Shell: sshExecShell,
Command: sshExecCommand,
PTY: sshExecPTY,
}
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build unix
package cmd
import (
"errors"
"io"
"os"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpUID uint32
sftpGID uint32
sftpGroupsInt []uint
sftpWorkingDir string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with privilege dropping (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sftpGroupsInt {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sftpUID,
GID: sftpGID,
Groups: groups,
WorkingDir: sftpWorkingDir,
Shell: "",
Command: "",
}
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
cmd.PrintErrf("privilege drop failed: %v\n", err)
os.Exit(sshserver.ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Tracef("starting SFTP server with dropped privileges")
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeShellExecFail)
}
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeSuccess)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build windows
package cmd
import (
"errors"
"fmt"
"io"
"os"
"os/user"
"strings"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpWorkingDir string
windowsUsername string
windowsDomain string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with user switching for Windows (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
return sftpMainDirect(cmd)
}
func sftpMainDirect(cmd *cobra.Command) error {
currentUser, err := user.Current()
if err != nil {
cmd.PrintErrf("failed to get current user: %v\n", err)
os.Exit(sshserver.ExitCodeValidationFail)
}
if windowsUsername != "" {
expectedUsername := windowsUsername
if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
}
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail)
}
}
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
if sftpWorkingDir != "" {
if err := os.Chdir(sftpWorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Debugf("starting SFTP server")
exitCode := sshserver.ExitCodeSuccess
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
exitCode = sshserver.ExitCodeShellExecFail
}
if err := sftpServer.Close(); err != nil {
log.Debugf("SFTP server close error: %v", err)
}
os.Exit(exitCode)
return nil
}

717
client/cmd/ssh_test.go Normal file
View File

@@ -0,0 +1,717 @@
package cmd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSHCommand_FlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedUser string
expectedPort int
expectedCmd string
expectError bool
}{
{
name: "basic host",
args: []string{"hostname"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "",
},
{
name: "user@host format",
args: []string{"user@hostname"},
expectedHost: "hostname",
expectedUser: "user",
expectedPort: 22,
expectedCmd: "",
},
{
name: "host with command",
args: []string{"hostname", "echo", "hello"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "echo hello",
},
{
name: "command with flags should be preserved",
args: []string{"hostname", "ls", "-la", "/tmp"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "ls -la /tmp",
},
{
name: "double dash separator",
args: []string{"hostname", "--", "ls", "-la"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "-- ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
// Mock command for testing
cmd := sshCmd
cmd.SetArgs(tt.args)
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
if tt.expectedUser != "" {
assert.Equal(t, tt.expectedUser, username, "username mismatch")
}
assert.Equal(t, tt.expectedPort, port, "port mismatch")
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
})
}
}
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
// Test that SSH flags don't conflict with command flags
tests := []struct {
name string
args []string
expectedCmd string
description string
}{
{
name: "ls with -la flags",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
description: "ls flags should be passed to remote command",
},
{
name: "grep with -r flag",
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
expectedCmd: "grep -r pattern /path",
description: "grep flags should be passed to remote command",
},
{
name: "ps with aux flags",
args: []string{"hostname", "ps", "aux"},
expectedCmd: "ps aux",
description: "ps flags should be passed to remote command",
},
{
name: "command with double dash",
args: []string{"hostname", "--", "ls", "-la"},
expectedCmd: "-- ls -la",
description: "double dash should be preserved in command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
// Test that commands with arguments should execute the command and exit,
// not drop to an interactive shell
tests := []struct {
name string
args []string
expectedCmd string
shouldExit bool
description string
}{
{
name: "ls command should execute and exit",
args: []string{"hostname", "ls"},
expectedCmd: "ls",
shouldExit: true,
description: "ls command should execute and exit, not drop to shell",
},
{
name: "ls with flags should execute and exit",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
shouldExit: true,
description: "ls with flags should execute and exit, not drop to shell",
},
{
name: "pwd command should execute and exit",
args: []string{"hostname", "pwd"},
expectedCmd: "pwd",
shouldExit: true,
description: "pwd command should execute and exit, not drop to shell",
},
{
name: "echo command should execute and exit",
args: []string{"hostname", "echo", "hello"},
expectedCmd: "echo hello",
shouldExit: true,
description: "echo command should execute and exit, not drop to shell",
},
{
name: "no command should open shell",
args: []string{"hostname"},
expectedCmd: "",
shouldExit: false,
description: "no command should open interactive shell",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// When command is present, it should execute the command and exit
// When command is empty, it should open interactive shell
hasCommand := command != ""
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
})
}
}
func TestSSHCommand_FlagHandling(t *testing.T) {
// Test that flags after hostname are not parsed by netbird but passed to SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "ls with -la flag should not be parsed by netbird",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
},
{
name: "command with netbird-like flags should be passed through",
args: []string{"hostname", "echo", "--help"},
expectedHost: "hostname",
expectedCmd: "echo --help",
expectError: false,
description: "--help should be passed to echo, not parsed by netbird",
},
{
name: "command with -p flag should not conflict with SSH port flag",
args: []string{"hostname", "ps", "-p", "1234"},
expectedHost: "hostname",
expectedCmd: "ps -p 1234",
expectError: false,
description: "ps -p should be passed to ps command, not parsed as port",
},
{
name: "tar with flags should be passed through",
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
expectedHost: "hostname",
expectedCmd: "tar -czf backup.tar.gz /home",
expectError: false,
description: "tar flags should be passed to tar command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
// should not parse -la as netbird flags but pass them to the SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "original issue: ls -la should be preserved",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "The original failing case should now work",
},
{
name: "ls -l should be preserved",
args: []string{"hostname", "ls", "-l"},
expectedHost: "hostname",
expectedCmd: "ls -l",
expectError: false,
description: "Single letter flags should be preserved",
},
{
name: "SSH port flag should work",
args: []string{"-p", "2222", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedCmd: "ls -la",
expectError: false,
description: "SSH -p flag should be parsed, command flags preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// Check port for the test case with -p flag
if len(tt.args) > 0 && tt.args[0] == "-p" {
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
}
})
}
}
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectError bool
description string
}{
{
name: "local port forwarding -L",
args: []string{"-L", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Single -L flag should be parsed correctly",
},
{
name: "remote port forwarding -R",
args: []string{"-R", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectError: false,
description: "Single -R flag should be parsed correctly",
},
{
name: "multiple local port forwards",
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
expectedRemote: []string{},
expectError: false,
description: "Multiple -L flags should be parsed correctly",
},
{
name: "multiple remote port forwards",
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
expectError: false,
description: "Multiple -R flags should be parsed correctly",
},
{
name: "mixed local and remote forwards",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectError: false,
description: "Mixed -L and -R flags should be parsed correctly",
},
{
name: "port forwarding with bind address",
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with bind address should work",
},
{
name: "port forwarding with command",
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with command should work",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
})
}
}
func TestParsePortForward(t *testing.T) {
tests := []struct {
name string
spec string
expectedLocal string
expectedRemote string
expectError bool
description string
}{
{
name: "simple port forward",
spec: "8080:localhost:80",
expectedLocal: "localhost:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Simple port:host:port format should work",
},
{
name: "port forward with bind address",
spec: "127.0.0.1:8080:localhost:80",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "bind_address:port:host:port format should work",
},
{
name: "port forward to different host",
spec: "8080:example.com:443",
expectedLocal: "localhost:8080",
expectedRemote: "example.com:443",
expectError: false,
description: "Forwarding to different host should work",
},
{
name: "port forward with IPv6 (needs bracket support)",
spec: "::1:8080:localhost:80",
expectError: true,
description: "IPv6 without brackets fails as expected (feature to implement)",
},
{
name: "invalid format - too few parts",
spec: "8080:localhost",
expectError: true,
description: "Invalid format with too few parts should fail",
},
{
name: "invalid format - too many parts",
spec: "127.0.0.1:8080:localhost:80:extra",
expectError: true,
description: "Invalid format with too many parts should fail",
},
{
name: "empty spec",
spec: "",
expectError: true,
description: "Empty spec should fail",
},
{
name: "unix socket local forward",
spec: "8080:/tmp/socket",
expectedLocal: "localhost:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket forwarding should work",
},
{
name: "unix socket with bind address",
spec: "127.0.0.1:8080:/tmp/socket",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket with bind address should work",
},
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
},
{
name: "wildcard for port only",
spec: "8080:*:80",
expectedLocal: "localhost:8080",
expectedRemote: "*:80",
expectError: false,
description: "Wildcard in remote host should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
if tt.expectError {
assert.Error(t, err, tt.description)
return
}
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
})
}
}
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
// Integration test for port forwarding with the actual SSH command implementation
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectedCmd string
description string
}{
{
name: "local forward with command",
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectedCmd: "echo test",
description: "Local forwarding should work with commands",
},
{
name: "remote forward with command",
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectedCmd: "ls -la",
description: "Remote forwarding should work with commands",
},
{
name: "multiple forwards with user and command",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectedCmd: "ps aux",
description: "Complex case with multiple forwards, user, and command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
})
}
}
func TestSSHCommand_ParameterIsolation(t *testing.T) {
tests := []struct {
name string
args []string
expectedCmd string
}{
{
name: "cmd flag passed as command",
args: []string{"hostname", "--cmd", "echo test"},
expectedCmd: "--cmd echo test",
},
{
name: "uid flag passed as command",
args: []string{"hostname", "--uid", "1000"},
expectedCmd: "--uid 1000",
},
{
name: "shell flag passed as command",
args: []string{"hostname", "--shell", "/bin/bash"},
expectedCmd: "--shell /bin/bash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
require.NoError(t, err)
assert.Equal(t, "hostname", host)
assert.Equal(t, tt.expectedCmd, command)
})
}
}
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
}
if err != nil {

View File

@@ -12,8 +12,11 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
@@ -23,8 +26,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -115,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}

View File

@@ -18,12 +18,16 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
var (
ErrClientAlreadyStarted = errors.New("client already started")
ErrClientNotStarted = errors.New("client not started")
ErrEngineNotStarted = errors.New("engine not started")
ErrConfigNotInitialized = errors.New("config not initialized")
)
// Client manages a netbird embedded client instance.
type Client struct {
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
engine, err := c.getEngine()
if err != nil {
return nil, err
}
nsnet, err := engine.GetNet()
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// DialContext dials a network address in the netbird network with context
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c.Dial(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
engine, err := c.getEngine()
if err != nil {
return err
}
storedKey, found := engine.GetPeerSSHKey(peerAddress)
if !found {
return sshcommon.ErrPeerNotFound
}
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
func (c *Client) getEngine() (*internal.Engine, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
if connect == nil {
return nil, ErrClientNotStarted
}
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
return nil, ErrEngineNotStarted
}
return engine, nil
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {
return nil, netip.Addr{}, err
}
addr, err := engine.Address()

View File

@@ -27,7 +27,11 @@ import (
)
const (
tableNat = "nat"
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("unable to list tables: %v", err)
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
log.Errorf("failed to refresh rules: %s", err)
}
return nil
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
func (r *router) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
Exprs: exprsOut,
})
return nil
return r.conn.Flush()
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables()
return r.acceptFilterRulesNftables(r.filterTable)
}
return r.acceptFilterRulesIptables(ipt)
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf,
},
}
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil
}
func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptFilterRulesNftables()
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
if chain.Table.Name != table.Name {
continue
}
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
return fmt.Errorf("get rules: %v", err)
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
return chains
}
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
return nil
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf(" unable to list rules: %v", err)
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {

View File

@@ -35,6 +35,12 @@ const (
ipTCPHeaderMinSize = 40
)
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
@@ -59,12 +65,6 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
return buf.Bytes()
}
func TestShouldForward(t *testing.T) {
// Set up test addresses
wgIP := netip.MustParseAddr("100.10.0.1")
otherIP := netip.MustParseAddr("100.10.0.2")
// Create test manager with mock interface
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// Set the mock to return our test WG IP
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Helper to create decoder with TCP packet
createTCPDecoder := func(dstPort uint16) *decoder {
ipv4 := &layers.IPv4{
Version: 4,
Protocol: layers.IPProtocolTCP,
SrcIP: net.ParseIP("192.168.1.100"),
DstIP: wgIP.AsSlice(),
}
tcp := &layers.TCP{
SrcPort: 54321,
DstPort: layers.TCPPort(dstPort),
}
err := tcp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
require.NoError(t, err)
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
require.NoError(t, err)
return d
}
tests := []struct {
name string
localForwarding bool
netstack bool
dstIP netip.Addr
serviceRegistered bool
servicePort uint16
expected bool
description string
}{
{
name: "no local forwarding",
localForwarding: false,
netstack: true,
dstIP: wgIP,
expected: false,
description: "should never forward when local forwarding disabled",
},
{
name: "traffic to other local interface",
localForwarding: true,
netstack: false,
dstIP: otherIP,
expected: true,
description: "should forward traffic to our other local interfaces (not NetBird IP)",
},
{
name: "traffic to NetBird IP, no netstack",
localForwarding: true,
netstack: false,
dstIP: wgIP,
expected: false,
description: "should send to netstack listeners (final return false path)",
},
{
name: "traffic to our IP, netstack mode, no service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
expected: true,
description: "should forward when in netstack mode with no matching service",
},
{
name: "traffic to our IP, netstack mode, with service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
serviceRegistered: true,
servicePort: 22,
expected: false,
description: "should send to netstack listeners when service is registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager
manager.localForwarding = tt.localForwarding
manager.netstack = tt.netstack
// Register service if needed
if tt.serviceRegistered {
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
}
// Create decoder for the test
decoder := createTCPDecoder(tt.servicePort)
if !tt.serviceRegistered {
decoder = createTCPDecoder(8080) // Use non-registered port
}
// Test the method
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -0,0 +1,85 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestPortDNATBasic tests basic port DNAT functionality
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rule for peer B (the target)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d := parsePacket(t, packetAtoB)
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
// Verify port was translated to 22022
d = parsePacket(t, packetAtoB)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
// (prevents double NAT - original port stored in conntrack)
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
d2 := parsePacket(t, returnPacket)
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
}
// TestPortDNATMultipleRules tests multiple port DNAT rules
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rules for both peers
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Test traffic to peer B gets translated
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d1 := parsePacket(t, packetToB)
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
require.True(t, translatedToB, "Traffic to peer B should be translated")
d1 = parsePacket(t, packetToB)
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
// Test traffic to peer A gets translated
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
d2 := parsePacket(t, packetToA)
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
require.True(t, translatedToA, "Traffic to peer A should be translated")
d2 = parsePacket(t, packetToA)
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"
@@ -12,7 +11,6 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -20,9 +18,6 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
@@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}
conn, err := grpc.NewClient(
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
return nil, fmt.Errorf("dial context: %w", err)
}
return conn, nil

View File

@@ -3,6 +3,7 @@
package device
import (
"fmt"
"strings"
log "github.com/sirupsen/logrus"
@@ -19,11 +20,12 @@ import (
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
// todo: review if we can eliminate the TunAdapter
tunAdapter TunAdapter
disableDNS bool
@@ -32,17 +34,19 @@ type WGTunDevice struct {
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
renewableTun *RenewableTUN
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
renewableTun: NewRenewableTUN(),
}
}
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to create Android interface: %s", err)
return nil, err
}
t.renewableTun.AddDevice(unmonitoredTUN)
t.name = name
t.filteredDevice = newDeviceFilter(tunDevice)
t.filteredDevice = newDeviceFilter(t.renewableTun)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) RenewTun(fd int) error {
if t.device == nil {
return fmt.Errorf("device not initialized")
}
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to renew Android interface: %s", err)
return err
}
t.renewableTun.AddDevice(unmonitoredTUN)
return nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement
return nil

View File

@@ -2,6 +2,13 @@
package device
import "fmt"
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}
func (t *TunNetstackDevice) RenewTun(fd int) error {
// Doesn't make sense in Android for Netstack.
return fmt.Errorf("this function has not been implemented in Netstack for Android")
}

View File

@@ -0,0 +1,309 @@
//go:build android
package device
import (
"io"
"os"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
//
// It also redirects tun.Device's Events() to a separate goroutine
// and closes it when Close is called.
//
// The WaitGroup and CloseOnce fields are used to ensure that the
// goroutine is awaited and closed only once.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel (RenewableTUN's events channel).
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
type RenewableTUN struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
events chan tun.Event
closed atomic.Bool
}
func NewRenewableTUN() *RenewableTUN {
r := &RenewableTUN{
devices: make([]*closeAwareDevice, 0),
mu: sync.Mutex{},
events: make(chan tun.Event, 16),
}
r.cond = sync.NewCond(&r.mu)
return r
}
func (r *RenewableTUN) File() *os.File {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// Read reads from an underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries reading from the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
dev := r.peekLast()
if dev == nil {
// wait until AddDevice() signals a new device via cond.Broadcast()
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
return 0, io.EOF
}
continue
}
n, err = dev.Read(bufs, sizes, offset)
if err == nil {
return n, nil
}
// swap in progress; retry on the newest instead of returning the error
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err // propagate non-swap error
}
}
// Write writes to underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries writing to the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (r *RenewableTUN) MTU() (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
func (r *RenewableTUN) Name() (string, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// Events returns a channel that is fed events from the underlying tun.Device's events channel
// once it is added.
func (r *RenewableTUN) Events() <-chan tun.Event {
return r.events
}
func (r *RenewableTUN) Close() error {
// Attempts to set the RenewableTUN closed flag to true.
// If it's already true, returns immediately.
if !r.closed.CompareAndSwap(false, true) {
return nil // already closed: idempotent
}
r.mu.Lock()
devices := r.devices
r.devices = nil
r.cond.Broadcast()
r.mu.Unlock()
var lastErr error
log.Debugf("closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
log.Debugf("error closing a device: %v", err)
lastErr = err
}
}
close(r.events)
return lastErr
}
func (r *RenewableTUN) BatchSize() int {
return 1
}
func (r *RenewableTUN) AddDevice(device tun.Device) {
r.mu.Lock()
if r.closed.Load() {
r.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(r.devices) > 0 {
toClose = r.devices[len(r.devices)-1]
}
cad := newClosableDevice(device)
cad.redirectEvents(r.events)
r.devices = []*closeAwareDevice{cad}
r.cond.Broadcast()
r.mu.Unlock()
if toClose != nil {
if err := toClose.Close(); err != nil {
log.Debugf("error closing last device: %v", err)
}
}
}
func (r *RenewableTUN) waitForDevice() bool {
r.mu.Lock()
defer r.mu.Unlock()
for len(r.devices) == 0 && !r.closed.Load() {
r.cond.Wait()
}
return !r.closed.Load()
}
func (r *RenewableTUN) peekLast() *closeAwareDevice {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.devices) == 0 {
return nil
}
return r.devices[len(r.devices)-1]
}

View File

@@ -21,5 +21,6 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
RenewTun(fd int) error
GetICEBind() device.EndpointManager
}

View File

@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on non-android")
}

View File

@@ -6,6 +6,7 @@ import (
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// todo: review does this function really necessary or can we merge it with iOS
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.RenewTun(fd)
}

View File

@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on this platform")
}

View File

@@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled
// If SSH enabled, add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {

View File

@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
SshConfig: &mgmProto.SSHConfig{
SshEnabled: true,
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false)
expectedRules := 3
if fw.IsStateful() {
expectedRules = 3 // 2 inbound rules + SSH rule
}
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
}

View File

@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}

View File

@@ -13,6 +13,7 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
excludedRanges := getSystemExcludedPortRanges()
for _, redirectURL := range config.RedirectURLs {
if !isRedirectURLPortUsed(redirectURL) {
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
availableRedirectURL = redirectURL
break
}
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
}
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
switch p.providerConfig.LoginFlag {
case common.LoginFlagPromptLogin:
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
case common.LoginFlagMaxAge0:
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
@@ -192,17 +195,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
if authErrorDesc != "" {
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
}
return nil, fmt.Errorf("authentication failed: %s", authError)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
return nil, fmt.Errorf("authentication failed: Invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
return nil, fmt.Errorf("authentication failed: missing code")
}
return p.oAuthConfig.Exchange(
@@ -231,7 +237,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
}
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
@@ -279,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
return base64.RawURLEncoding.EncodeToString(sha2[:])
}
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
log.Errorf("failed to parse redirect URL: %v", err)
return true
}
addr := fmt.Sprintf(":%s", parsedURL.Port())
port := parsedURL.Port()
if isPortInExcludedRange(port, excludedRanges) {
log.Warnf("port %s is in Windows excluded port range, skipping", port)
return true
}
addr := fmt.Sprintf(":%s", port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
@@ -301,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
return true
}
// excludedPortRange represents a range of excluded ports.
type excludedPortRange struct {
start int
end int
}
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
if len(excludedRanges) == 0 {
return false
}
portNum, err := strconv.Atoi(port)
if err != nil {
log.Debugf("invalid port number %s: %v", port, err)
return false
}
for _, r := range excludedRanges {
if portNum >= r.start && portNum <= r.end {
return true
}
}
return false
}
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
if err != nil {

View File

@@ -0,0 +1,8 @@
//go:build !windows
package auth
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
func getSystemExcludedPortRanges() []excludedPortRange {
return nil
}

View File

@@ -2,8 +2,11 @@ package auth
import (
"context"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
expectContains []string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
name: "Prompt login",
loginFlag: mgm.LoginFlagPromptLogin,
expectContains: []string{promptLogin},
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
name: "Max age 0",
loginFlag: mgm.LoginFlagMaxAge0,
expectContains: []string{maxAge0},
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
loginFlag: mgm.LoginFlagPromptLogin,
disablePromptLogin: true,
expectContains: []string{},
},
{
name: "None flag should not add parameters",
loginFlag: mgm.LoginFlagNone,
expectContains: []string{},
},
}
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
DisablePromptLogin: tc.disablePromptLogin,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
for _, expected := range tc.expectContains {
require.Contains(t, authInfo.VerificationURIComplete, expected)
}
})
}
}
func TestIsPortInExcludedRange(t *testing.T) {
tests := []struct {
name string
port string
excludedRanges []excludedPortRange
expectedBlocked bool
}{
{
name: "Port in excluded range",
port: "8080",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port at start of range",
port: "8000",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port at end of range",
port: "8100",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port before range",
port: "7999",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Port after range",
port: "8101",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Empty excluded ranges",
port: "8080",
excludedRanges: []excludedPortRange{},
expectedBlocked: false,
},
{
name: "Nil excluded ranges",
port: "8080",
excludedRanges: nil,
expectedBlocked: false,
},
{
name: "Multiple ranges - port in second range",
port: "9050",
excludedRanges: []excludedPortRange{
{start: 8000, end: 8100},
{start: 9000, end: 9100},
},
expectedBlocked: true,
},
{
name: "Multiple ranges - port not in any range",
port: "8500",
excludedRanges: []excludedPortRange{
{start: 8000, end: 8100},
{start: 9000, end: 9100},
},
expectedBlocked: false,
},
{
name: "Invalid port string",
port: "invalid",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Empty port string",
port: "",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
})
}
}
func TestIsRedirectURLPortUsed(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = listener.Close()
}()
usedPort := listener.Addr().(*net.TCPAddr).Port
tests := []struct {
name string
redirectURL string
excludedRanges []excludedPortRange
expectedUsed bool
}{
{
name: "Port in excluded range",
redirectURL: "http://127.0.0.1:8080/",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedUsed: true,
},
{
name: "Port actually in use",
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
excludedRanges: nil,
expectedUsed: true,
},
{
name: "Port not in use and not excluded",
redirectURL: "http://127.0.0.1:65432/",
excludedRanges: nil,
expectedUsed: false,
},
{
name: "Invalid URL without port",
redirectURL: "not-a-valid-url",
excludedRanges: nil,
expectedUsed: false,
},
{
name: "Port excluded even if not in use",
redirectURL: "http://127.0.0.1:8050/",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedUsed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
})
}
}

View File

@@ -0,0 +1,86 @@
//go:build windows
package auth
import (
"bufio"
"fmt"
"os/exec"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
func getSystemExcludedPortRanges() []excludedPortRange {
ranges, err := getExcludedPortRangesFromNetsh()
if err != nil {
log.Debugf("failed to get Windows excluded port ranges: %v", err)
return nil
}
return ranges
}
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("netsh command: %w", err)
}
return parseExcludedPortRanges(string(output))
}
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
var ranges []excludedPortRange
scanner := bufio.NewScanner(strings.NewReader(output))
foundHeader := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
foundHeader = true
continue
}
if !foundHeader {
continue
}
if strings.Contains(line, "----------") {
continue
}
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
startPort, err := strconv.Atoi(fields[0])
if err != nil {
continue
}
endPort, err := strconv.Atoi(fields[1])
if err != nil {
continue
}
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("scan output: %w", err)
}
return ranges, nil
}

View File

@@ -0,0 +1,116 @@
//go:build windows
package auth
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
tests := []struct {
name string
netshOutput string
expectedRanges []excludedPortRange
expectError bool
}{
{
name: "Valid netsh output with multiple ranges",
netshOutput: `
Protocol tcp Dynamic Port Range
---------------------------------
Start Port : 49152
Number of Ports : 16384
Protocol tcp Excluded Port Ranges
---------------------------------
Start Port End Port
---------- --------
5357 5357 *
50000 50059 *
`,
expectedRanges: []excludedPortRange{
{start: 5357, end: 5357},
{start: 50000, end: 50059},
},
expectError: false,
},
{
name: "Empty output",
netshOutput: `
Protocol tcp Dynamic Port Range
---------------------------------
Start Port : 49152
Number of Ports : 16384
`,
expectedRanges: nil,
expectError: false,
},
{
name: "Single range",
netshOutput: `
Protocol tcp Excluded Port Ranges
---------------------------------
Start Port End Port
---------- --------
8080 8090
`,
expectedRanges: []excludedPortRange{
{start: 8080, end: 8090},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ranges, err := parseExcludedPortRanges(tt.netshOutput)
if tt.expectError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedRanges, ranges)
}
})
}
}
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
ranges := getSystemExcludedPortRanges()
t.Logf("Found %d excluded port ranges on this system", len(ranges))
listener1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = listener1.Close()
}()
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
availablePort := 65432
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
},
UseIDToken: true,
}
flow, err := NewPKCEAuthorizationFlow(config)
require.NoError(t, err)
require.NotNil(t, flow)
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
"Should skip port in use and select available port")
}

View File

@@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid(
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid(
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
}
@@ -271,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -291,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done()
c.engineMutex.Lock()
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
@@ -416,20 +421,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
@@ -515,6 +525,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
}
if g.internalConfig.EnableSSHRoot != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
}
if g.internalConfig.EnableSSHSFTP != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
}
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
}
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

View File

@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
continue
}
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
Domain: zoneName,
Records: records,
SearchDomainDisabled: true,
}
config.CustomZones = append(config.CustomZones, reverseZone)

View File

@@ -11,11 +11,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
const (
ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa."
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
}
for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: matchOnly,
MatchOnly: customZone.SearchDomainDisabled,
})
}

View File

@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
logger.Warn(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)

View File

@@ -9,7 +9,6 @@ import (
"net/netip"
"net/url"
"os"
"reflect"
"runtime"
"slices"
"sort"
@@ -30,7 +29,6 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -51,10 +49,10 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
@@ -115,7 +113,12 @@ type EngineConfig struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerSSHAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DNSRouteInterval time.Duration
@@ -148,8 +151,6 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex
config *EngineConfig
mobileDep MobileDependency
@@ -175,8 +176,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
sshServer sshServer
statusRecorder *peer.Status
@@ -246,7 +246,6 @@ func NewEngine(
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
@@ -256,7 +255,7 @@ func NewEngine(
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" {
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
@@ -268,6 +267,7 @@ func NewEngine(
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
@@ -292,8 +291,11 @@ func (e *Engine) Stop() error {
}
log.Info("Network monitor: stopped")
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if err := e.stopSSHServer(); err != nil {
log.Warnf("failed to stop SSH server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
@@ -302,24 +304,29 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
log.Errorf("failed to remove all peers: %s", err)
}
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
e.stopDNSForwarder()
// stop/restore DNS after peers are closed but before interface goes down
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
@@ -331,16 +338,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
if err := e.stateManager.Stop(stateCtx); err != nil {
log.Errorf("failed to stop state manager: %v", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -426,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
}
err := e.rpManager.Run()
if err != nil {
if err := e.rpManager.Run(); err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
}
}
@@ -479,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -703,16 +712,10 @@ func (e *Engine) removeAllPeers() error {
return nil
}
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
// removePeer closes an existing peer connection and removes a peer
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
e.sshMux.Lock()
if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.sshMux.Unlock()
e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -789,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
nm := update.GetNetworkMap()
if nm == nil {
if nm == nil || update.SkipNetworkMapUpdate {
return nil
}
@@ -884,6 +892,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -893,74 +906,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
return nil
}
func isNil(server nbssh.Server) bool {
return server == nil || reflect.ValueOf(server).IsNil()
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is not enabled")
return nil
}
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
e.sshMux.Lock()
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err)
}
e.sshServer = server
e.sshMux.Unlock()
go func() {
// blocking
err = server.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.sshMux.Lock()
e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server")
}()
} else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running")
}
} else {
e.sshMux.Lock()
if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
e.sshMux.Unlock()
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
@@ -973,8 +918,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
if conf.GetSshConfig() != nil {
err := e.updateSSH(conf.GetSshConfig())
if err != nil {
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling SSH server setup: %v", err)
}
}
@@ -1012,9 +956,14 @@ func (e *Engine) receiveManagementEvents() {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1170,19 +1119,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
}
}
}
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.sshMux.Unlock()
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1259,6 +1200,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
//nolint
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
@@ -1273,7 +1215,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1433,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1544,15 +1493,6 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil)
}
e.sshMux.Lock()
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
e.sshMux.Unlock()
if e.firewall != nil {
err := e.firewall.Close(e.stateManager)
if err != nil {
@@ -1583,6 +1523,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)
@@ -1901,6 +1846,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
return e.wgInterface.Address().IP
}
func (e *Engine) RenewTun(fd int) error {
e.syncMsgMux.Lock()
wgInterface := e.wgInterface
e.syncMsgMux.Unlock()
if wgInterface == nil {
return fmt.Errorf("wireguard interface not initialized")
}
return wgInterface.RenewTun(fd)
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(
enabled bool,

View File

@@ -0,0 +1,355 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
"strings"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
}
func (e *Engine) setupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("add SSH port redirection: %w", err)
}
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
return e.stopSSHServer()
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is disabled in config")
return e.stopSSHServer()
}
if !sshConf.GetSshEnabled() {
if e.config.ServerSSHAllowed {
log.Info("SSH server is locally allowed but disabled by management server")
}
return e.stopSSHServer()
}
if e.sshServer != nil {
log.Debug("SSH server is already running")
return nil
}
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
}
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil
}
configMgr := sshconfig.New()
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
}
return nil
}
// extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers {
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
}
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
if len(peerConfig.GetAllowedIps()) == 0 {
return ""
}
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
return prefix.Addr().String()
}
return ""
}
// extractHostname extracts short hostname from FQDN
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
fqdn := peerConfig.GetFqdn()
if fqdn == "" {
return ""
}
parts := strings.Split(fqdn, ".")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
}
}
log.Debugf("updated peer SSH host keys for daemon API access")
}
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
e.sshServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
if err := e.setupSSHPortRedirection(); err != nil {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
return nil
}
// configureSSHServer applies SSH configuration options to the server.
func (e *Engine) configureSSHServer(server *sshserver.Server) {
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
server.SetAllowRootLogin(true)
log.Info("SSH root login enabled")
} else {
server.SetAllowRootLogin(false)
log.Info("SSH root login disabled (default)")
}
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
server.SetAllowSFTP(true)
log.Info("SSH SFTP subsystem enabled")
} else {
server.SetAllowSFTP(false)
log.Info("SSH SFTP subsystem disabled (default)")
}
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
server.SetAllowLocalPortForwarding(true)
log.Info("SSH local port forwarding enabled")
} else {
server.SetAllowLocalPortForwarding(false)
log.Info("SSH local port forwarding disabled (default)")
}
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
server.SetAllowRemotePortForwarding(true)
log.Info("SSH remote port forwarding enabled")
} else {
server.SetAllowRemotePortForwarding(false)
log.Info("SSH remote port forwarding disabled (default)")
}
}
func (e *Engine) cleanupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("remove SSH port redirection: %w", err)
}
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) stopSSHServer() error {
if e.sshServer == nil {
return nil
}
if err := e.cleanupSSHPortRedirection(); err != nil {
log.Warnf("failed to cleanup SSH port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
log.Info("stopping SSH server")
err := e.sshServer.Stop()
e.sshServer = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
// GetSSHServerStatus returns the SSH server status and active sessions
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
e.syncMsgMux.Lock()
sshServer := e.sshServer
e.syncMsgMux.Unlock()
if sshServer == nil {
return false, nil
}
return sshServer.GetStatus()
}

View File

@@ -0,0 +1,79 @@
package internal
import (
"context"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}
resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}
// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,14 +24,18 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -46,13 +49,12 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -108,6 +110,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) RenewTun(_ int) error {
return nil
}
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
@@ -214,11 +220,13 @@ func TestMain(m *testing.M) {
}
func TestEngine_SSH(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH")
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
key, err := wgtypes.GeneratePrivateKey()
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
@@ -240,6 +248,7 @@ func TestEngine_SSH(t *testing.T) {
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
@@ -250,35 +259,8 @@ func TestEngine_SSH(t *testing.T) {
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
var sshKeysAdded []string
var sshPeersRemoved []string
sshCtx, cancel := context.WithCancel(context.Background())
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
return &ssh.MockServer{
Ctx: sshCtx,
StopFunc: func() error {
cancel()
return nil
},
StartFunc: func() error {
<-ctx.Done()
return ctx.Err()
},
AddAuthorizedKeyFunc: func(peer, newKey string) error {
sshKeysAdded = append(sshKeysAdded, newKey)
return nil
},
RemoveAuthorizedKeyFunc: func(peer string) {
sshPeersRemoved = append(sshPeersRemoved, peer)
},
}, nil
}
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer func() {
err := engine.Stop()
@@ -304,9 +286,7 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -314,19 +294,24 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
// now remove peer
networkMap = &mgmtProto.NetworkMap{
@@ -336,13 +321,10 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
@@ -354,12 +336,70 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false, // Start with SSH disabled
},
syncMsgMux: &sync.Mutex{},
}
// Test SSH disabled config
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
err := engine.updateSSH(sshConfig)
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
// Test inbound blocked
engine.config.BlockInbound = true
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
engine.config.BlockInbound = false
// Test with server SSH not allowed
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHServerConsistency(t *testing.T) {
t.Run("server set only on successful creation", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: true,
SSHKey: []byte("test-key"),
},
syncMsgMux: &sync.Mutex{},
}
engine.wgInterface = nil
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.Error(t, err)
assert.Nil(t, engine.sshServer)
})
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false,
},
syncMsgMux: &sync.Mutex{},
}
err := engine.stopSSHServer()
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
})
}
func TestEngine_UpdateNetworkMap(t *testing.T) {
@@ -591,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -1588,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -20,6 +20,7 @@ import (
type wgIfaceBase interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
RenewTun(fd int) error
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address

View File

@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"os"
"runtime"
"strings"
)
@@ -10,5 +11,8 @@ const (
)
func isForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -67,6 +67,7 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
SSHHostKey []byte
routes map[string]struct{}
}
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
return nil
}
// UpdatePeerSSHHostKey updates peer's SSH host key
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peerPubKey]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.SSHHostKey = sshHostKey
d.peers[peerPubKey] = peerState
return nil
}
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()

View File

@@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
ManagementURL string
AdminURL string
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
DisableClientRoutes *bool
DisableServerRoutes *bool
@@ -82,18 +88,24 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
ManagementURL *url.URL
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
PrivateKey string
PreSharedKey string
ManagementURL *url.URL
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
DisableServerRoutes bool
@@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
} else {
log.Infof("disabling SSH root login")
}
config.EnableSSHRoot = input.EnableSSHRoot
updated = true
}
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
if *input.EnableSSHSFTP {
log.Infof("enabling SSH SFTP subsystem")
} else {
log.Infof("disabling SSH SFTP subsystem")
}
config.EnableSSHSFTP = input.EnableSSHSFTP
updated = true
}
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
if *input.EnableSSHLocalPortForwarding {
log.Infof("enabling SSH local port forwarding")
} else {
log.Infof("disabling SSH local port forwarding")
}
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
updated = true
}
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
if *input.EnableSSHRemotePortForwarding {
log.Infof("enabling SSH remote port forwarding")
} else {
log.Infof("disabling SSH remote port forwarding")
}
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
updated = true
}
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
log.Infof("enabling SSH authentication")
}
config.DisableSSHAuth = input.DisableSSHAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
tests := []struct {
name string
wireguardPort *int
expectedPort int
description string
name string
wireguardPort *int
expectedPort int
description string
}{
{
name: "no port specified uses default",

View File

@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
return nil
}
// GetLoginHint retrieves the email from the active profile to use as login_hint.
func GetLoginHint() string {
pm := NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
log.Debugf("failed to get active profile for login hint: %v", err)
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""
}
return profileState.Email
}

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -39,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/version"

View File

@@ -0,0 +1,218 @@
//go:build darwin && !ios
package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import (
"context"
"fmt"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
)
//export sleepCallbackBridge
func sleepCallbackBridge() {
log.Info("sleepCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
}
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
}
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
if _, exists := serviceRegistry[d]; exists {
return fmt.Errorf("detector service already registered")
}
d.callback = callback
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil
}
serviceRegistry[d] = struct{}{}
// CFRunLoop must run on a single fixed OS thread
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.registerNotifications()
}()
log.Info("sleep detection service started on macOS")
return nil
}
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
_, exists := serviceRegistry[d]
if !exists {
return nil
}
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 {
return nil
}
log.Info("sleep detection service stopping (deregister)")
// Deregister IOKit notifications, stop runloop, and free resources
C.unregisterNotifications()
return nil
}
func (d *Detector) triggerCallback(event EventType) {
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
cb := d.callback
go func(callback func(event EventType)) {
log.Info("sleep detection event fired")
callback(event)
close(doneChan)
}(cb)
select {
case <-doneChan:
case <-d.ctx.Done():
case <-timeout.C:
log.Warnf("sleep callback timed out")
}
}

View File

@@ -0,0 +1,9 @@
//go:build !darwin || ios
package sleep
import "fmt"
func NewDetector() (detector, error) {
return nil, fmt.Errorf("sleep not supported on this platform")
}

View File

@@ -0,0 +1,37 @@
package sleep
var (
EventTypeUnknown EventType = 0
EventTypeSleep EventType = 1
EventTypeWakeUp EventType = 2
)
type EventType int
type detector interface {
Register(callback func(eventType EventType)) error
Deregister() error
}
type Service struct {
detector detector
}
func New() (*Service, error) {
d, err := NewDetector()
if err != nil {
return nil, err
}
return &Service{
detector: d,
}, nil
}
func (s *Service) Register(callback func(eventType EventType)) error {
return s.detector.Register(callback)
}
func (s *Service) Deregister() error {
return s.detector.Deregister()
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,299 @@
package templates
import (
"html/template"
"os"
"path/filepath"
"testing"
)
func TestPKCEAuthMsgTemplate(t *testing.T) {
tests := []struct {
name string
data map[string]string
outputFile string
expectedTitle string
expectedInContent []string
notExpectedInContent []string
}{
{
name: "error_state",
data: map[string]string{
"Error": "authentication failed: invalid state",
},
outputFile: "pkce-auth-error.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication failed: invalid state",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
{
name: "success_state",
data: map[string]string{
// No error field means success
},
outputFile: "pkce-auth-success.html",
expectedTitle: "Login Successful",
expectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird. You can now close this window.",
},
notExpectedInContent: []string{
"Login Failed",
},
},
{
name: "error_state_timeout",
data: map[string]string{
"Error": "authentication timeout: request expired after 5 minutes",
},
outputFile: "pkce-auth-timeout.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication timeout: request expired after 5 minutes",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse the template
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Create temp directory for this test
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, tt.outputFile)
// Create output file
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
// Execute the template
if err := tmpl.Execute(file, tt.data); err != nil {
file.Close()
t.Fatalf("Failed to execute template: %v", err)
}
file.Close()
t.Logf("Generated test output: %s", outputPath)
// Read the generated file
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
contentStr := string(content)
// Verify file has content
if len(contentStr) == 0 {
t.Error("Output file is empty")
}
// Verify basic HTML structure
basicElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"NetBird",
}
for _, elem := range basicElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
// Verify expected title
if !contains(contentStr, tt.expectedTitle) {
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
}
// Verify expected content is present
for _, expected := range tt.expectedInContent {
if !contains(contentStr, expected) {
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
}
}
// Verify unexpected content is not present
for _, notExpected := range tt.notExpectedInContent {
if contains(contentStr, notExpected) {
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
}
}
})
}
}
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
// Test that the template can be parsed without errors
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
// Test with empty data
t.Run("empty_data", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "empty-data.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
if err := tmpl.Execute(file, nil); err != nil {
t.Errorf("Template execution with nil data failed: %v", err)
}
})
// Test with error data
t.Run("with_error", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "with-error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{
"Error": "test error message",
}
if err := tmpl.Execute(file, data); err != nil {
t.Errorf("Template execution with error data failed: %v", err)
}
})
}
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
// Test that the template contains expected elements
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
t.Run("success_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "success.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for success indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Successful",
"NetBird",
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
t.Run("error_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
errorMsg := "test error message"
data := map[string]string{
"Error": errorMsg,
}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for error indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Failed",
errorMsg,
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK
import (
"context"
"fmt"
"net/netip"
"os"
"sort"
"strings"
"sync"
@@ -20,8 +23,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error {
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile,
})
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
if err != nil {
return err.Error()
}
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
}
return netIDs
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
log.Debugf("Setting env variable %s: %s", k, v)
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
} else {
log.Debugf("Env variable %s was set successfully", k)
}
}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ service DaemonService {
// Status of the service.
rpc Status(StatusRequest) returns (StatusResponse) {}
// Down engine work in the daemon.
// Down stops engine work in the daemon.
rpc Down(DownRequest) returns (DownResponse) {}
// GetConfig of the daemon.
@@ -84,9 +84,35 @@ service DaemonService {
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
// RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
}
message OSLifecycleRequest {
// avoid collision with loglevel enum
enum CycleType {
UNKNOWN = 0;
SLEEP = 1;
WAKEUP = 2;
}
CycleType type = 1;
}
message OSLifecycleResponse {}
message LoginRequest {
// setupKey netbird setup key.
string setupKey = 1;
@@ -161,6 +187,13 @@ message LoginRequest {
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
optional bool enableSSHRoot = 34;
optional bool enableSSHSFTP = 35;
optional bool enableSSHLocalPortForwarding = 36;
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
}
message LoginResponse {
@@ -188,9 +221,9 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
optional bool waitForReady = 3;
}
message StatusResponse{
@@ -255,6 +288,18 @@ message GetConfigResponse {
bool disable_server_routes = 19;
bool block_lan_access = 20;
bool enableSSHRoot = 21;
bool enableSSHSFTP = 24;
bool enableSSHLocalPortForwarding = 22;
bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
}
// PeerState contains the latest state of a peer
@@ -276,6 +321,7 @@ message PeerState {
repeated string networks = 16;
google.protobuf.Duration latency = 17;
string relayAddress = 18;
bytes sshHostKey = 19;
}
// LocalPeerState contains the latest state of the local peer
@@ -317,6 +363,20 @@ message NSGroupState {
string error = 4;
}
// SSHSessionInfo contains information about an active SSH session
message SSHSessionInfo {
string username = 1;
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
}
// SSHServerState contains the latest state of the SSH server
message SSHServerState {
bool enabled = 1;
repeated SSHSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -330,6 +390,7 @@ message FullStatus {
repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
}
// Networks
@@ -543,56 +604,63 @@ message SwitchProfileRequest {
message SwitchProfileResponse {}
message SetConfigRequest {
string username = 1;
string profileName = 2;
// managementUrl to authenticate.
string managementUrl = 3;
string username = 1;
string profileName = 2;
// managementUrl to authenticate.
string managementUrl = 3;
// adminUrl to manage keys.
string adminURL = 4;
// adminUrl to manage keys.
string adminURL = 4;
optional bool rosenpassEnabled = 5;
optional bool rosenpassEnabled = 5;
optional string interfaceName = 6;
optional string interfaceName = 6;
optional int64 wireguardPort = 7;
optional int64 wireguardPort = 7;
optional string optionalPreSharedKey = 8;
optional string optionalPreSharedKey = 8;
optional bool disableAutoConnect = 9;
optional bool disableAutoConnect = 9;
optional bool serverSSHAllowed = 10;
optional bool serverSSHAllowed = 10;
optional bool rosenpassPermissive = 11;
optional bool rosenpassPermissive = 11;
optional bool networkMonitor = 12;
optional bool networkMonitor = 12;
optional bool disable_client_routes = 13;
optional bool disable_server_routes = 14;
optional bool disable_dns = 15;
optional bool disable_firewall = 16;
optional bool block_lan_access = 17;
optional bool disable_client_routes = 13;
optional bool disable_server_routes = 14;
optional bool disable_dns = 15;
optional bool disable_firewall = 16;
optional bool block_lan_access = 17;
optional bool disable_notifications = 18;
optional bool disable_notifications = 18;
optional bool lazyConnectionEnabled = 19;
optional bool lazyConnectionEnabled = 19;
optional bool block_inbound = 20;
optional bool block_inbound = 20;
repeated string natExternalIPs = 21;
bool cleanNATExternalIPs = 22;
repeated string natExternalIPs = 21;
bool cleanNATExternalIPs = 22;
bytes customDNSAddress = 23;
bytes customDNSAddress = 23;
repeated string extraIFaceBlacklist = 24;
repeated string extraIFaceBlacklist = 24;
repeated string dns_labels = 25;
// cleanDNSLabels clean map list of DNS labels.
bool cleanDNSLabels = 26;
repeated string dns_labels = 25;
// cleanDNSLabels clean map list of DNS labels.
bool cleanDNSLabels = 26;
optional google.protobuf.Duration dnsRouteInterval = 27;
optional google.protobuf.Duration dnsRouteInterval = 27;
optional int64 mtu = 28;
optional int64 mtu = 28;
optional bool enableSSHRoot = 29;
optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForwarding = 31;
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
}
message SetConfigResponse{}
@@ -644,3 +712,63 @@ message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
}
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
message GetPeerSSHHostKeyRequest {
// peer IP address or FQDN to get SSH host key for
string peerAddress = 1;
}
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
message GetPeerSSHHostKeyResponse {
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
bytes sshHostKey = 1;
// peer IP address
string peerIP = 2;
// peer FQDN
string peerFQDN = 3;
// indicates if the SSH host key was found
bool found = 4;
}
// RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest {
// hint for OIDC login_hint parameter (typically email address)
optional string hint = 1;
}
// RequestJWTAuthResponse contains authentication flow information
message RequestJWTAuthResponse {
// verification URI for user authentication
string verificationURI = 1;
// complete verification URI (with embedded user code)
string verificationURIComplete = 2;
// user code to enter on verification URI
string userCode = 3;
// device code for polling
string deviceCode = 4;
// expiration time in seconds
int64 expiresIn = 5;
// if a cached token is available, it will be returned here
string cachedToken = 6;
// maximum age of JWT tokens in seconds (from management server)
int64 maxTokenAge = 7;
}
// WaitJWTTokenRequest for waiting for authentication completion
message WaitJWTTokenRequest {
// device code from RequestJWTAuthResponse
string deviceCode = 1;
// user code for verification
string userCode = 2;
}
// WaitJWTTokenResponse contains the JWT token after authentication
message WaitJWTTokenResponse {
// JWT token (access token or ID token)
string token = 1;
// token type (e.g., "Bearer")
string tokenType = 2;
// expiration time in seconds
int64 expiresIn = 3;
}

View File

@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
// Status of the service.
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
// Down engine work in the daemon.
// Down stops engine work in the daemon.
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
@@ -64,6 +64,13 @@ type DaemonServiceClient interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
}
type daemonServiceClient struct {
@@ -349,6 +356,42 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
return out, nil
}
func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
out := new(GetPeerSSHHostKeyResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
out := new(RequestJWTAuthResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
out := new(WaitJWTTokenResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
out := new(OSLifecycleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -362,7 +405,7 @@ type DaemonServiceServer interface {
Up(context.Context, *UpRequest) (*UpResponse, error)
// Status of the service.
Status(context.Context, *StatusRequest) (*StatusResponse, error)
// Down engine work in the daemon.
// Down stops engine work in the daemon.
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
@@ -399,6 +442,13 @@ type DaemonServiceServer interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -490,6 +540,18 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
}
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
}
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
}
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1010,6 +1072,78 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetPeerSSHHostKeyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RequestJWTAuthRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(WaitJWTTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/WaitJWTToken",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(OSLifecycleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1125,6 +1259,22 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetFeatures",
Handler: _DaemonService_GetFeatures_Handler,
},
{
MethodName: "GetPeerSSHHostKey",
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
},
{
MethodName: "RequestJWTAuth",
Handler: _DaemonService_RequestJWTAuth_Handler,
},
{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -0,0 +1,79 @@
package server
import (
"sync"
"time"
"github.com/awnumar/memguard"
log "github.com/sirupsen/logrus"
)
type jwtCache struct {
mu sync.RWMutex
enclave *memguard.Enclave
expiresAt time.Time
timer *time.Timer
maxTokenSize int
}
func newJWTCache() *jwtCache {
return &jwtCache{
maxTokenSize: 8192,
}
}
func (c *jwtCache) store(token string, maxAge time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
if c.timer != nil {
c.timer.Stop()
}
tokenBytes := []byte(token)
c.enclave = memguard.NewEnclave(tokenBytes)
c.expiresAt = time.Now().Add(maxAge)
var timer *time.Timer
timer = time.AfterFunc(maxAge, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.timer != timer {
return
}
c.cleanup()
c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
})
c.timer = timer
}
func (c *jwtCache) get() (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.enclave == nil || time.Now().After(c.expiresAt) {
return "", false
}
buffer, err := c.enclave.Open()
if err != nil {
log.Debugf("Failed to open JWT token enclave: %v", err)
return "", false
}
defer buffer.Destroy()
token := string(buffer.Bytes())
return token, true
}
// cleanup destroys the secure enclave, must be called with lock held
func (c *jwtCache) cleanup() {
if c.enclave != nil {
c.enclave = nil
}
c.expiresAt = time.Time{}
}

View File

@@ -0,0 +1,77 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -0,0 +1,219 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -11,8 +11,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
type selectRoute struct {

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,11 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +108,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
}
@@ -373,6 +382,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -487,13 +507,13 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err
}
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -510,7 +530,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -802,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil {
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
@@ -894,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}
@@ -1065,12 +1087,235 @@ func (s *Server) Status(
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
return &statusResponse, nil
}
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
func (s *Server) getSSHServerState() *proto.SSHServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetSSHServerStatus()
sshServerState := &proto.SSHServerState{
Enabled: enabled,
}
for _, session := range sessions {
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
Username: session.Username,
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
})
}
return sshServerState
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
req *proto.GetPeerSSHHostKeyRequest,
) (*proto.GetPeerSSHHostKeyResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
if connectClient == nil {
return nil, errors.New("client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil
}
response.SshHostKey = hostKey
if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
break
}
}
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
func (s *Server) getJWTCacheTTL() time.Duration {
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil || config.SSHJWTCacheTTL == nil {
return defaultJWTCacheTTL
}
seconds := *config.SSHJWTCacheTTL
if seconds == 0 {
log.Debug("SSH JWT cache disabled (configured to 0)")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
msg *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
if hint == "" {
hint = profilemanager.GetLoginHint()
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
@@ -1136,25 +1381,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableServerRoutes := cfg.DisableServerRoutes
blockLANAccess := cfg.BlockLANAccess
enableSSHRoot := false
if cfg.EnableSSHRoot != nil {
enableSSHRoot = *cfg.EnableSSHRoot
}
enableSSHSFTP := false
if cfg.EnableSSHSFTP != nil {
enableSSHSFTP = *cfg.EnableSSHSFTP
}
enableSSHLocalPortForwarding := false
if cfg.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
}
enableSSHRemotePortForwarding := false
if cfg.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if cfg.DisableSSHAuth != nil {
disableSSHAuth = *cfg.DisableSSHAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
EnableSSHRoot: enableSSHRoot,
EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}
@@ -1385,6 +1666,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}

View File

@@ -14,13 +14,15 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -34,7 +36,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -315,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{
ProfileName: profName,
@@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
}
_, err = s.SetConfig(ctx, req)
@@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
require.NotNil(t, cfg.SSHJWTCacheTTL)
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
verifyAllFieldsCovered(t, req)
}
@@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"EnableSSHRoot": true,
"EnableSSHSFTP": true,
"EnableSSHLocalPortForwarding": true,
"EnableSSHRemotePortForwarding": true,
"DisableSSHAuth": true,
"SshJWTCacheTTL": true,
}
val := reflect.ValueOf(req).Elem()
@@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// Map of CLI flag names to their corresponding SetConfigRequest field names.
// This map must be updated when adding new config-related CLI flags.
flagToField := map[string]string{
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
"enable-ssh-root": "EnableSSHRoot",
"enable-ssh-sftp": "EnableSSHSFTP",
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
"disable-ssh-auth": "DisableSSHAuth",
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).

View File

@@ -6,9 +6,11 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -1,118 +0,0 @@
//go:build !js
package ssh
import (
"fmt"
"net"
"os"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
// Client wraps crypto/ssh Client to simplify usage
type Client struct {
client *ssh.Client
}
// Close closes the wrapped SSH Client
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal starts an interactive terminal session with the remote SSH server
func (c *Client) OpenTerminal() error {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("failed to open new session: %v", err)
}
defer func() {
err := session.Close()
if err != nil {
return
}
}()
fd := int(os.Stdout.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
return fmt.Errorf("failed to run raw terminal: %s", err)
}
defer func() {
err := term.Restore(fd, state)
if err != nil {
return
}
}()
w, h, err := term.GetSize(fd)
if err != nil {
return fmt.Errorf("terminal get size: %s", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
}
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
if err := session.Shell(); err != nil {
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
}
if err := session.Wait(); err != nil {
if e, ok := err.(*ssh.ExitError); ok {
if e.ExitStatus() == 130 {
return nil
}
}
return fmt.Errorf("failed running SSH session: %s", err)
}
return nil
}
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
config := &ssh.ClientConfig{
User: user,
Timeout: 5 * time.Second,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}
return Dial("tcp", addr, config)
}
// Dial connects to the remote SSH server.
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
client, err := ssh.Dial(network, addr, config)
if err != nil {
return nil, err
}
return &Client{
client: client,
}, nil
}

Some files were not shown because too many files have changed in this diff Show More