Compare commits

..

91 Commits

Author SHA1 Message Date
Viktor Liu
24b66fb406 Translate usernames to UPN format for domain login 2025-11-05 22:27:08 +01:00
Viktor Liu
9378b6b0a3 Merge branch 'ssh-rewrite' into move-licensed-code 2025-11-05 16:09:03 +01:00
Viktor Liu
3779a3385f Fix tests 2025-11-05 13:06:54 +01:00
Viktor Liu
b5d75ad9c4 Go fmt everything 2025-11-05 12:59:36 +01:00
Viktor Liu
8db91abfdf Merge branch 'main' into ssh-rewrite 2025-11-05 12:44:17 +01:00
Viktor Liu
6f817cad6d Remove duplicate code 2025-11-03 13:47:33 +01:00
Viktor Liu
e3bb8c1b7b Merge branch 'main' into ssh-rewrite 2025-11-03 13:43:29 +01:00
Viktor Liu
107066fa3d Merge branch 'main' into ssh-rewrite 2025-10-28 22:08:46 +01:00
Viktor Liu
a7a85d4dc8 Fix tests 2025-10-28 21:11:45 +01:00
Viktor Liu
576b4a779c Log shell 2025-10-28 18:15:53 +01:00
Viktor Liu
e6854dfd99 Improve session logging 2025-10-28 17:57:59 +01:00
Viktor Liu
6f14134988 Merge branch 'main' into ssh-rewrite 2025-10-28 16:50:23 +01:00
Viktor Liu
4fd64379da Move client-imported GPL code to separate package 2025-10-23 23:52:44 +02:00
Viktor Liu
c20202a6c3 Add new flags to test 2025-10-17 16:15:05 +02:00
Viktor Liu
4386a21956 Merge branch 'main' into ssh-rewrite 2025-10-17 15:34:36 +02:00
Zoltan Papp
5882daf5d9 Force relay connection, do not waste signaling resources on ICE connection (#4628) 2025-10-13 11:02:21 +02:00
Viktor Liu
11d71e6e22 Ignore default log file 2025-10-10 16:21:39 +02:00
Viktor Liu
4dadcfd9bd Remove client.log check 2025-10-10 16:17:46 +02:00
Viktor Liu
34b55c600e Log errors on debug 2025-10-10 16:11:13 +02:00
Viktor Liu
316c0afa9a Remove unused arg 2025-10-10 11:08:34 +02:00
Viktor Liu
cf97799db8 Fix test 2025-10-10 10:23:45 +02:00
Viktor Liu
4d297205c3 Fix test build 2025-10-09 17:26:25 +02:00
Viktor Liu
559f6aeeaf Improve logging 2025-10-08 18:54:56 +02:00
Viktor Liu
7216c201da Log priv check errors 2025-10-08 18:46:02 +02:00
Viktor Liu
4d89d0f115 Remove unused code 2025-10-08 18:39:41 +02:00
Viktor Liu
610c880ec9 Fix missing jwt config passed to peers 2025-10-08 16:47:11 +02:00
Viktor Liu
19adcb5f63 Merge branch 'main' into ssh-rewrite 2025-10-08 12:40:07 +02:00
Viktor Liu
f3d31698da Skip some auth tests on windows that are already covered 2025-10-07 23:39:01 +02:00
Viktor Liu
d9efe4e944 Add ssh authenatication with jwt (#4550) 2025-10-07 23:38:27 +02:00
Viktor Liu
7e0bbaaa3c Merge branch 'main' into ssh-rewrite 2025-10-07 09:41:07 +02:00
Viktor Liu
b3c7b3c7b2 Fix js build 2025-10-02 15:59:17 +02:00
Viktor Liu
66483ab48d Merge branch 'main' into ssh-rewrite 2025-10-02 15:53:12 +02:00
Viktor Liu
5272fc2b18 Merge branch 'main' into ssh-rewrite 2025-09-25 11:12:47 +02:00
Viktor Liu
4c53372815 Add missing flags 2025-08-27 09:59:12 +02:00
Viktor Liu
79d28b71ee Improve forwarding cancellation 2025-08-26 22:22:15 +02:00
Viktor Liu
77a352763d Fix button style 2025-08-26 21:19:04 +02:00
Viktor Liu
cdd5c6c005 Address review 2025-08-26 21:01:55 +02:00
Viktor Liu
b1a9242c98 Fix merge commit changes 2025-08-26 20:43:29 +02:00
Viktor Liu
b43ef4f17b Merge branch 'main' into ssh-rewrite 2025-08-26 20:09:47 +02:00
Viktor Liu
758a97c352 Generate ssh_config independently of ssh server 2025-07-14 22:02:41 +02:00
Viktor Liu
d93b7c2f38 Fix known hosts entries 2025-07-14 21:41:59 +02:00
Viktor Liu
fa893aa0a4 Fix build 2025-07-12 00:49:08 +02:00
Viktor Liu
ac7120871b Fix proto 2025-07-12 00:11:31 +02:00
Viktor Liu
9a7daa132e Fix client ssh file 2025-07-11 22:08:28 +02:00
Viktor Liu
cdded8c22e Merge branch 'main' into ssh-rewrite 2025-07-11 22:05:12 +02:00
Viktor Liu
e4e0b8fff9 Remove empty file 2025-07-04 17:09:54 +02:00
Viktor Liu
a4b067553d Merge branch 'main' into ssh-rewrite 2025-07-04 16:53:54 +02:00
Viktor Liu
088956645f Fix username validation and skip ci tests properly 2025-07-03 15:36:42 +02:00
Viktor Liu
aa30b7afe8 More windows tests 2025-07-03 14:11:20 +02:00
Viktor Liu
f1bb4d2ac3 Fix more Windows tests 2025-07-03 13:35:53 +02:00
Viktor Liu
982841e25b Test up tests users if none are available on CI 2025-07-03 12:33:31 +02:00
Viktor Liu
a476b8d12f Fix more windows tests 2025-07-03 11:26:04 +02:00
Viktor Liu
a21f924b26 Fix some windows tests 2025-07-03 10:20:16 +02:00
Viktor Liu
9e51d2e8fb Fix lint and sonar 2025-07-03 09:58:25 +02:00
Viktor Liu
3e490d974c Remove duplicated code 2025-07-03 03:40:27 +02:00
Viktor Liu
04bb314426 Allow sftp same user switching on windows 2025-07-03 02:19:12 +02:00
Viktor Liu
6e15882c11 Fix tests and windows username validation 2025-07-03 01:58:15 +02:00
Viktor Liu
76f9e11b29 Fix tests 2025-07-03 01:07:58 +02:00
Viktor Liu
612de2c784 Remove socketfilter temporarily 2025-07-02 22:00:10 +02:00
Viktor Liu
1fdde66c31 More lint 2025-07-02 21:55:25 +02:00
Viktor Liu
5970591d24 Fix lint 2025-07-02 21:32:39 +02:00
Viktor Liu
0d5408baec Fix lint 2025-07-02 21:04:58 +02:00
Viktor Liu
96084e3a02 Reduce complexity 2025-07-02 20:43:17 +02:00
Viktor Liu
4bbca28eb6 Fix lint 2025-07-02 20:23:23 +02:00
Viktor Liu
279b77dee0 Bump sftp 2025-07-02 19:42:57 +02:00
Viktor Liu
9d1554f9f7 Complete overhaul 2025-07-02 19:35:19 +02:00
Viktor Liu
f56075ca15 Tidy mod 2025-07-02 19:34:36 +02:00
Viktor Liu
6ed846ae29 Refactor ssh server and client 2025-07-02 19:34:36 +02:00
Viktor Liu
520f2cfdb4 Remove implicit inbound ssh firewall rules and change default port 2025-07-02 19:34:32 +02:00
Viktor Liu
0f79a8942d Fix route notificaiton 2025-07-02 17:24:14 +02:00
Viktor Liu
5299e9fda3 Merge branch 'main' into android-dns-routes 2025-07-02 15:23:14 +02:00
Viktor Liu
11bdf5b3a5 Use r 2025-06-26 15:41:56 +02:00
Viktor Liu
5fc95d4a0c Display domains properly 2025-06-26 15:36:14 +02:00
Viktor Liu
c7884039b8 Revert "Fix errorf"
This reverts commit 26fc32f1be.
2025-06-25 15:17:31 +02:00
Viktor Liu
26fc32f1be Fix errorf 2025-06-25 15:03:55 +02:00
Viktor Liu
a79cb1c11b Merge branch 'main' into android-dns-routes 2025-06-18 17:27:13 +02:00
Viktor Liu
306d75fe1a Set up fake ip route only if the dns feature flag is enabled 2025-06-17 22:29:13 +02:00
Viktor Liu
9468e69c8c Extract static error 2025-06-17 21:47:05 +02:00
Viktor Liu
f51ce7cee5 Remove nil checks 2025-06-17 21:41:58 +02:00
Viktor Liu
d47c6b624e Fix spelling 2025-06-17 20:02:52 +02:00
Viktor Liu
471f90e8db Rename methods 2025-06-17 15:52:34 +02:00
Viktor Liu
1a3b04d2fe Swap tracking and nat order 2025-06-17 15:45:22 +02:00
Viktor Liu
51b9e93eb9 Merge branch 'main' into android-dns-routes 2025-06-17 15:12:05 +02:00
Viktor Liu
2952669e97 Fix lint 2025-06-17 14:16:59 +02:00
Viktor Liu
7cd44a9a3c Improve nat perf 2025-06-17 13:55:57 +02:00
Viktor Liu
8684981b57 Add tests 2025-06-17 13:41:06 +02:00
Viktor Liu
8e94d85d14 Rename test files 2025-06-17 12:46:17 +02:00
Viktor Liu
631b77dc3c Remove some allocations 2025-06-17 12:44:52 +02:00
Viktor Liu
50ac3d437e Fix lint issues 2025-06-17 03:07:28 +02:00
Viktor Liu
49bbd90557 Fix test 2025-06-17 02:57:15 +02:00
Viktor Liu
bb74e903cd Implement dns routes for Android 2025-06-17 02:48:13 +02:00
207 changed files with 4959 additions and 16155 deletions

View File

@@ -3,108 +3,40 @@ name: Check License Dependencies
on:
push:
branches: [ main ]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs:
check-internal-dependencies:
name: Check Internal AGPL Dependencies
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses
check-dependencies:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
- name: Check for problematic license dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo "✅ All license dependencies are clean"
fi

View File

@@ -15,14 +15,13 @@ jobs:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -30,7 +30,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Get Go environment
@@ -106,15 +106,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -151,15 +151,15 @@ jobs:
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.24-alpine \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -220,15 +220,15 @@ jobs:
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -270,15 +270,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -321,15 +321,15 @@ jobs:
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -408,16 +408,15 @@ jobs:
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -498,15 +497,15 @@ jobs:
-p 9090:9090 \
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -562,15 +561,15 @@ jobs:
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -24,7 +24,7 @@ jobs:
uses: actions/setup-go@v5
id: go
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Get Go environment

View File

@@ -46,7 +46,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
@@ -39,7 +39,7 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build android netbird lib
@@ -56,9 +56,9 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib

View File

@@ -20,7 +20,7 @@ concurrency:
jobs:
release:
runs-on: ubuntu-latest-m
runs-on: ubuntu-22.04
env:
flags: ""
steps:
@@ -40,7 +40,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -136,7 +136,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -200,7 +200,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -67,13 +67,10 @@ jobs:
- name: Install curl
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
@@ -83,6 +80,9 @@ jobs:
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
@@ -45,7 +45,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
@@ -60,8 +60,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

View File

@@ -4,13 +4,10 @@ package android
import (
"context"
"fmt"
"os"
"slices"
"sync"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
@@ -19,13 +16,10 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -68,18 +62,17 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: platformFiles.ConfigurationFilePath(),
cfgFile: cfgFile,
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -87,12 +80,11 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
@@ -115,7 +107,7 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.login(urlOpener, isAndroidTV)
err = auth.login(urlOpener)
if err != nil {
return err
}
@@ -123,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -150,7 +142,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
// Stop the internal client and free the resources
@@ -164,19 +156,6 @@ func (c *Client) Stop() {
c.ctxCancel()
}
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
return e.RenewTun(fd)
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -198,7 +177,6 @@ func (c *Client) PeersList() *PeerInfoArray {
p.IP,
p.FQDN,
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi
}
@@ -223,43 +201,31 @@ func (c *Client) Networks() *NetworkArray {
return nil
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
log.Error("could not get route selector")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
@@ -287,69 +253,6 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func (c *Client) toggleRoute(command routeCommand) error {
return command.toggleRoute()
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}
engine := client.Engine()
if engine == nil {
return nil, fmt.Errorf("engine is not running")
}
manager := engine.GetRouteManager()
if manager == nil {
return nil, fmt.Errorf("could not get route manager")
}
return manager, nil
}
func (c *Client) SelectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
}
func (c *Client) DeselectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
}
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
// with its resolved IP addresses from the provided resolvedDomains map.
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
domains := NetworkDomains{}
for _, d := range route.Domains {
networkDomain := NetworkDomain{
Address: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
for _, prefix := range info.Prefixes {
networkDomain.addResolvedIP(prefix.Addr().String())
}
}
domains.Add(&networkDomain)
}
return domains
}
func exportEnvList(list *EnvList) {
if list == nil {
return

View File

@@ -3,7 +3,15 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -24,7 +32,7 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
Open(url string, userCode string)
Open(string)
OnLoginSuccess()
}
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -110,24 +131,26 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
err = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
}
// Login try register the client on the server
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
go func() {
err := a.login(urlOpener, isAndroidTV)
err := a.login(urlOpener)
if err != nil {
resultListener.OnError(err)
} else {
@@ -136,40 +159,48 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}()
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
func (a *Auth) login(urlOpener URLOpener) error {
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
go urlOpener.OnLoginSuccess()
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
if err != nil {
return nil, err
}
@@ -179,12 +210,24 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
go urlOpener.Open(flowInfo.VerificationURIComplete)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -1,56 +0,0 @@
//go:build android
package android
import "fmt"
type ResolvedIPs struct {
resolvedIPs []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
}
func (r *ResolvedIPs) Get(i int) (string, error) {
if i < 0 || i >= len(r.resolvedIPs) {
return "", fmt.Errorf("%d is out of range", i)
}
return r.resolvedIPs[i], nil
}
func (r *ResolvedIPs) Size() int {
return len(r.resolvedIPs)
}
type NetworkDomain struct {
Address string
resolvedIPs ResolvedIPs
}
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
d.resolvedIPs.Add(resolvedIP)
}
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type NetworkDomains struct {
domains []*NetworkDomain
}
func (n *NetworkDomains) Add(domain *NetworkDomain) {
n.domains = append(n.domains, domain)
}
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
if i < 0 || i >= len(n.domains) {
return nil, fmt.Errorf("%d is out of range", i)
}
return n.domains[i], nil
}
func (n *NetworkDomains) Size() int {
return len(n.domains)
}

View File

@@ -3,16 +3,10 @@
package android
type Network struct {
Name string
Network string
Peer string
Status string
IsSelected bool
Domains NetworkDomains
}
func (n Network) GetNetworkDomains() *NetworkDomains {
return &n.Domains
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {

View File

@@ -1,5 +1,3 @@
//go:build android
package android
// PeerInfo describe information about the peers. It designed for the UI usage
@@ -7,11 +5,6 @@ type PeerInfo struct {
IP string
FQDN string
ConnStatus string // Todo replace to enum
Routes PeerRoutes
}
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
return &p.Routes
}
// PeerInfoArray is a wrapper of []PeerInfo

View File

@@ -1,20 +0,0 @@
//go:build android
package android
import "fmt"
type PeerRoutes struct {
routes []string
}
func (p *PeerRoutes) Get(i int) (string, error) {
if i < 0 || i >= len(p.routes) {
return "", fmt.Errorf("%d is out of range", i)
}
return p.routes[i], nil
}
func (p *PeerRoutes) Size() int {
return len(p.routes)
}

View File

@@ -1,10 +0,0 @@
//go:build android
package android
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
// at their default locations due to android OS restrictions.
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
}

View File

@@ -1,67 +0,0 @@
//go:build android
package android
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/route"
)
func executeRouteToggle(id string, manager routemanager.Manager,
operationName string,
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
netID := route.NetID(id)
routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}
manager.TriggerSelection(manager.GetClientRoutes())
return nil
}
type routeCommand interface {
toggleRoute() error
}
type selectRouteCommand struct {
route string
manager routemanager.Manager
}
func (s selectRouteCommand) toggleRoute() error {
routeSelector := s.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
return routeSelector.SelectRoutes(routes, true, allRoutes)
}
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
}
type deselectRouteCommand struct {
route string
manager routemanager.Manager
}
func (d deselectRouteCommand) toggleRoute() error {
routeSelector := d.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
}

View File

@@ -2,13 +2,13 @@ package cmd
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/util"
)
@@ -107,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
@@ -249,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -277,50 +269,54 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
needsLogin := false
err = authClient.Login(ctx, "", "")
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = authClient.Login(ctx, setupKey, jwtToken)
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil {
return nil, err
}
@@ -332,7 +328,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
}
const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
log.Debug("systemd-networkd not in use, skipping configuration")
parentDir := filepath.Dir(networkdConfDir)
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)

View File

@@ -14,9 +14,7 @@ import (
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
@@ -36,7 +34,6 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var (
@@ -50,7 +47,6 @@ var (
knownHostsFile string
identityFile string
skipCachedToken bool
requestPTY bool
)
var (
@@ -60,7 +56,6 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
)
func init() {
@@ -70,16 +65,13 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
@@ -105,9 +97,9 @@ SSH Options:
-p, --port int Remote SSH port (default 22)
-u, --user string SSH username
--login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file
-i, --identity string Path to SSH private key file
Examples:
netbird ssh peer-hostname
@@ -115,10 +107,8 @@ Examples:
netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true,
@@ -153,10 +143,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()
}()
@@ -164,10 +154,6 @@ func sshFn(cmd *cobra.Command, args []string) error {
select {
case <-sig:
cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done():
}
@@ -365,7 +351,6 @@ type sshFlags struct {
Port int
Username string
Login string
RequestPTY bool
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
@@ -388,24 +373,22 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
fs.String("user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
fs.String("known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
fs.String("config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags
}
@@ -426,10 +409,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
return parseHostnameAndCommand(filteredArgs)
}
remaining := fs.Args()
@@ -444,7 +424,6 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
username = flags.Login
}
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
@@ -541,29 +520,10 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
// executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err)
}
return nil
@@ -575,13 +535,6 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err)
}
return nil
@@ -749,9 +702,7 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
@@ -767,11 +718,6 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
@@ -790,8 +736,7 @@ var sshDetectCmd = &cobra.Command{
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(detectLogLevel, "console"); err != nil {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
@@ -800,21 +745,15 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
port, err := strconv.Atoi(portStr)
if err != nil {
log.Debugf("invalid port %q: %v", portStr, err)
os.Exit(detection.ServerTypeRegular.ExitCode())
}
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
log.Debugf("SSH server detection failed: %v", err)
cancel()
os.Exit(detection.ServerTypeRegular.ExitCode())
}
cancel()
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -8,7 +8,6 @@ import (
"io"
"os"
"os/user"
"strings"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
@@ -52,7 +51,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
}
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail)
}

View File

@@ -667,51 +667,3 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
})
}
}
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
}
if err != nil {

View File

@@ -13,10 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -88,6 +84,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
@@ -113,18 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -286,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error
var loginResp *proto.LoginResponse
@@ -362,18 +355,14 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -478,10 +467,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -602,11 +587,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -169,13 +168,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}

View File

@@ -1,14 +1,13 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -41,13 +40,19 @@ type aclManager struct {
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
}, nil
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
@@ -93,8 +98,8 @@ func (m *aclManager) AddPeerFiltering(
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
@@ -108,18 +113,14 @@ func (m *aclManager) AddPeerFiltering(
}}, nil
}
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
}
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
@@ -171,16 +172,11 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
@@ -194,7 +190,10 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
@@ -207,16 +206,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
@@ -275,19 +264,11 @@ func (m *aclManager) cleanChains() error {
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
m.ipsetStore.deleteIpset(ipsetName)
}
@@ -387,8 +368,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is 0.0.0.0
if ip.IsUnspecified() {
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
@@ -435,61 +416,3 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
ipset "github.com/lrh3321/ipset-go"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -107,6 +107,10 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil
}
@@ -228,12 +232,12 @@ func (r *router) findSets(rule []string) []string {
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
@@ -242,7 +246,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
}
func (r *router) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
@@ -911,8 +915,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
@@ -989,37 +993,3 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -91,7 +91,11 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Warnf("failed to load filter table, skipping accept rules: %v", err)
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
}
return r, nil
@@ -171,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
return nil, fmt.Errorf("unable to list tables: %v", err)
}
for _, table := range tables {
@@ -189,6 +193,8 @@ func (r *router) createContainers() error {
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
@@ -230,12 +236,9 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
@@ -247,7 +250,11 @@ func (r *router) createContainers() error {
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
return nil
@@ -688,7 +695,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() {
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
@@ -754,6 +761,8 @@ func (r *router) addPostroutingRules() {
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
@@ -830,7 +839,7 @@ func (r *router) addMSSClampingRules() error {
Exprs: exprsOut,
})
return r.conn.Flush()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
@@ -1059,7 +1068,7 @@ func (r *router) acceptFilterRulesNftables() error {
}
r.conn.InsertRule(inputRule)
return r.conn.Flush()
return nil
}
func (r *router) removeAcceptFilterRules() error {
@@ -1187,7 +1196,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("list rules: %w", err)
return fmt.Errorf(" unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -18,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
@@ -26,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -43,22 +68,25 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.NewClient(
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
return nil, fmt.Errorf("dial context: %w", err)
return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil

View File

@@ -3,7 +3,6 @@
package device
import (
"fmt"
"strings"
log "github.com/sirupsen/logrus"
@@ -20,12 +19,11 @@ import (
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
// todo: review if we can eliminate the TunAdapter
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
@@ -34,19 +32,17 @@ type WGTunDevice struct {
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
renewableTun *RenewableTUN
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
renewableTun: NewRenewableTUN(),
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
}
}
@@ -69,17 +65,14 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to create Android interface: %s", err)
return nil, err
}
t.renewableTun.AddDevice(unmonitoredTUN)
t.name = name
t.filteredDevice = newDeviceFilter(t.renewableTun)
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
@@ -111,23 +104,6 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) RenewTun(fd int) error {
if t.device == nil {
return fmt.Errorf("device not initialized")
}
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to renew Android interface: %s", err)
return err
}
t.renewableTun.AddDevice(unmonitoredTUN)
return nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement
return nil

View File

@@ -2,13 +2,6 @@
package device
import "fmt"
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}
func (t *TunNetstackDevice) RenewTun(fd int) error {
// Doesn't make sense in Android for Netstack.
return fmt.Errorf("this function has not been implemented in Netstack for Android")
}

View File

@@ -1,309 +0,0 @@
//go:build android
package device
import (
"io"
"os"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
//
// It also redirects tun.Device's Events() to a separate goroutine
// and closes it when Close is called.
//
// The WaitGroup and CloseOnce fields are used to ensure that the
// goroutine is awaited and closed only once.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel (RenewableTUN's events channel).
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
type RenewableTUN struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
events chan tun.Event
closed atomic.Bool
}
func NewRenewableTUN() *RenewableTUN {
r := &RenewableTUN{
devices: make([]*closeAwareDevice, 0),
mu: sync.Mutex{},
events: make(chan tun.Event, 16),
}
r.cond = sync.NewCond(&r.mu)
return r
}
func (r *RenewableTUN) File() *os.File {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// Read reads from an underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries reading from the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
dev := r.peekLast()
if dev == nil {
// wait until AddDevice() signals a new device via cond.Broadcast()
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
return 0, io.EOF
}
continue
}
n, err = dev.Read(bufs, sizes, offset)
if err == nil {
return n, nil
}
// swap in progress; retry on the newest instead of returning the error
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err // propagate non-swap error
}
}
// Write writes to underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries writing to the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (r *RenewableTUN) MTU() (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
func (r *RenewableTUN) Name() (string, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// Events returns a channel that is fed events from the underlying tun.Device's events channel
// once it is added.
func (r *RenewableTUN) Events() <-chan tun.Event {
return r.events
}
func (r *RenewableTUN) Close() error {
// Attempts to set the RenewableTUN closed flag to true.
// If it's already true, returns immediately.
if !r.closed.CompareAndSwap(false, true) {
return nil // already closed: idempotent
}
r.mu.Lock()
devices := r.devices
r.devices = nil
r.cond.Broadcast()
r.mu.Unlock()
var lastErr error
log.Debugf("closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
log.Debugf("error closing a device: %v", err)
lastErr = err
}
}
close(r.events)
return lastErr
}
func (r *RenewableTUN) BatchSize() int {
return 1
}
func (r *RenewableTUN) AddDevice(device tun.Device) {
r.mu.Lock()
if r.closed.Load() {
r.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(r.devices) > 0 {
toClose = r.devices[len(r.devices)-1]
}
cad := newClosableDevice(device)
cad.redirectEvents(r.events)
r.devices = []*closeAwareDevice{cad}
r.cond.Broadcast()
r.mu.Unlock()
if toClose != nil {
if err := toClose.Close(); err != nil {
log.Debugf("error closing last device: %v", err)
}
}
}
func (r *RenewableTUN) waitForDevice() bool {
r.mu.Lock()
defer r.mu.Unlock()
for len(r.devices) == 0 && !r.closed.Load() {
r.cond.Wait()
}
return !r.closed.Load()
}
func (r *RenewableTUN) peekLast() *closeAwareDevice {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.devices) == 0 {
return nil
}
return r.devices[len(r.devices)-1]
}

View File

@@ -21,6 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
RenewTun(fd int) error
GetICEBind() device.EndpointManager
}

View File

@@ -24,7 +24,3 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on non-android")
}

View File

@@ -6,7 +6,6 @@ import (
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// todo: review does this function really necessary or can we merge it with iOS
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
@@ -23,9 +22,3 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.RenewTun(fd)
}

View File

@@ -39,7 +39,3 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on this platform")
}

View File

@@ -1,7 +1,6 @@
package iface
import (
"context"
"fmt"
"net"
"net/netip"
@@ -10,13 +9,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// keep darwin compatibility
@@ -41,7 +40,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -124,7 +123,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -167,7 +166,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -212,7 +211,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -285,7 +284,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
wgPort := 33100
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -340,7 +339,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -410,7 +409,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -472,7 +471,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -515,7 +514,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet(context.Background(), nil)
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}

View File

@@ -1,7 +1,6 @@
package udpmux
import (
"context"
"fmt"
"io"
"net"
@@ -13,9 +12,8 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
/*
@@ -201,7 +199,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}

View File

@@ -1,287 +0,0 @@
package auth
import (
"context"
"errors"
"fmt"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// Auth manages authentication operations with the management server
// The underlying management client handles connection retry and reconnection automatically
type Auth struct {
client *mgm.GrpcClient
config *profilemanager.Config
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The management client handles connection retry and reconnection automatically
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient := mgm.NewClient(mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err := mgmClient.Connect(ctx); err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
// Try PKCE flow first
_, err := a.getPKCEFlow(ctx)
if err == nil {
return true, nil
}
// Check if PKCE is not supported
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(ctx)
if err == nil {
return true, nil
}
// Check if Device flow is also not supported
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
// Neither PKCE nor Device flow is supported
return false, nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return false, err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return false, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
err = a.doMgmLogin(ctx, pubSSHKey)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login attempts to log in or register the client with the management server
// Returns custom errors from mgm package: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) error {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return fmt.Errorf("generate SSH public key: %w", err)
}
err = a.doMgmLogin(ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
return a.registerPeer(ctx, setupKey, jwtToken, pubSSHKey)
}
return err
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(ctx context.Context) (*PKCEAuthorizationFlow, error) {
protoFlow, err := a.client.GetPKCEAuthorizationFlow(ctx)
if err != nil {
if errors.Is(err, mgm.ErrNotFound) {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(ctx context.Context) (*DeviceAuthorizationFlow, error) {
protoFlow, err := a.client.GetDeviceAuthorizationFlow(ctx)
if err != nil {
if errors.Is(err, mgm.ErrNotFound) {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
_, err := a.client.Login(ctx, sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) error {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return fmt.Errorf("%w: invalid setup-key or no SSO information provided: %v", mgm.ErrInvalidArgument, err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
// todo: fix error handling of validSetupKey
if err := a.client.Register(ctx, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels); err != nil {
log.Errorf("failed registering peer %v", err)
return err
}
log.Infof("peer has been successfully registered on Management Service")
return nil
}
// isPermissionDenied checks if the error is a PermissionDenied error
func isPermissionDenied(err error) bool {
return errors.Is(err, mgm.ErrPermissionDenied)
}
// isLoginNeeded checks if the error indicates login is required
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
return errors.Is(err, mgm.ErrInvalidArgument) ||
errors.Is(err, mgm.ErrPermissionDenied) ||
errors.Is(err, mgm.ErrUnauthenticated)
}
// isRegistrationNeeded checks if the error indicates peer registration is needed
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -25,56 +26,12 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -176,34 +128,9 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err
}
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)
@@ -247,22 +174,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,6 +12,8 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,6 +10,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -59,60 +60,39 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
return authenticateWithDeviceCodeFlow(ctx, config, hint)
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config, hint)
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(ctx)
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
return pkceFlowInfo, nil
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(ctx)
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -127,9 +107,5 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
return deviceFlowInfo, nil
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -19,8 +19,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
@@ -33,67 +33,17 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig PKCEAuthProviderConfig
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
@@ -159,9 +109,6 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
@@ -171,21 +118,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -196,7 +132,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -207,8 +143,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
go p.startServer(server, tokenChan, errChan)
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:
@@ -253,20 +189,17 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
if authErrorDesc != "" {
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
}
return nil, fmt.Errorf("authentication failed: %s", authError)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("authentication failed: Invalid state")
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("authentication failed: missing code")
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
@@ -295,7 +228,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
email, err := parseEmailFromIDToken(tokenInfo.IDToken)

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -40,7 +41,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -74,7 +74,6 @@ func (c *ConnectClient) RunOnAndroid(
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -83,7 +82,6 @@ func (c *ConnectClient) RunOnAndroid(
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
}
@@ -162,11 +160,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
// Create management client once outside retry loop
mgmClient := mgm.NewClient(c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -185,9 +178,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
if err := mgmClient.ConnectWithoutRetry(engineCtx); err != nil {
return wrapErr(fmt.Errorf("failed connecting to Management Service: %w", err))
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
@@ -200,7 +196,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
log.Debug(err)
if errors.Is(err, mgm.ErrPermissionDenied) {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
@@ -322,7 +318,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if errors.Is(err, mgm.ErrPermissionDenied) {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
}
@@ -506,6 +502,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -524,7 +526,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(ctx, sysInfo, pubSSHKey, config.DNSLabels)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}

View File

@@ -44,8 +44,6 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
@@ -186,20 +184,6 @@ The ip_rules.txt file contains detailed IP routing rule information:
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
DNS Configuration
The debug bundle includes platform-specific DNS configuration files:
resolv.conf (Unix systems):
- Contains DNS resolver configuration from /etc/resolv.conf
- Includes nameserver entries, search domains, and resolver options
- All IP addresses and domain names are anonymized following the same rules as other files
scutil_dns.txt (macOS only):
- Contains detailed DNS configuration from scutil --dns
- Shows DNS configuration for all network interfaces
- Includes search domains, nameservers, and DNS resolver settings
- All IP addresses and domain names are anonymized
`
const (
@@ -373,10 +357,6 @@ func (g *BundleGenerator) addSystemInfo() {
if err := g.addFirewallRules(); err != nil {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
}
func (g *BundleGenerator) addReadme() error {

View File

@@ -1,53 +0,0 @@
//go:build darwin && !ios
package debug
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
if err := g.addScutilDNS(); err != nil {
log.Errorf("failed to add scutil DNS output: %v", err)
}
return nil
}
func (g *BundleGenerator) addScutilDNS() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--dns")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("execute scutil --dns: %w", err)
}
if len(bytes.TrimSpace(output)) == 0 {
return fmt.Errorf("no scutil DNS output")
}
content := string(output)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
return fmt.Errorf("add scutil DNS output to zip: %w", err)
}
return nil
}

View File

@@ -5,7 +5,3 @@ package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -1,16 +0,0 @@
//go:build unix && !darwin && !android
package debug
import (
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
return nil
}

View File

@@ -1,7 +0,0 @@
//go:build !unix
package debug
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -1,29 +0,0 @@
//go:build unix && !android
package debug
import (
"fmt"
"os"
"strings"
)
const resolvConfPath = "/etc/resolv.conf"
func (g *BundleGenerator) addResolvConf() error {
data, err := os.ReadFile(resolvConfPath)
if err != nil {
return fmt.Errorf("read %s: %w", resolvConfPath, err)
}
content := string(data)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
return fmt.Errorf("add resolv.conf to zip: %w", err)
}
return nil
}

View File

@@ -0,0 +1,134 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -76,9 +76,6 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
continue
}
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
@@ -109,9 +106,8 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
SearchDomainDisabled: true,
Domain: zoneName,
Records: records,
}
config.CustomZones = append(config.CustomZones, reverseZone)

View File

@@ -11,6 +11,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
const (
ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa."
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
@@ -105,9 +110,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
}
for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: customZone.SearchDomainDisabled,
MatchOnly: matchOnly,
})
}

View File

@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
newNet, err := stdnet.NewNet([]string{"utun2301"})
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err

View File

@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warn(timeoutMsg)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {

View File

@@ -255,7 +255,7 @@ func NewEngine(
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
@@ -891,7 +891,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(e.ctx, info); err != nil {
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
@@ -1192,7 +1192,6 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
//nolint
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
@@ -1207,9 +1206,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
Domain: zone.GetDomain(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1517,7 +1514,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(e.ctx, info)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, false, err
}
@@ -1666,7 +1663,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy(e.ctx)
managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy)
stuns := slices.Clone(e.STUNs)
@@ -1833,18 +1830,6 @@ func (e *Engine) GetWgAddr() netip.Addr {
return e.wgInterface.Address().IP
}
func (e *Engine) RenewTun(fd int) error {
e.syncMsgMux.Lock()
wgInterface := e.wgInterface
e.syncMsgMux.Unlock()
if wgInterface == nil {
return fmt.Errorf("wireguard interface not initialized")
}
return wgInterface.RenewTun(fd)
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(
enabled bool,

View File

@@ -19,7 +19,6 @@ import (
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -235,17 +234,7 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
e.sshServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
@@ -254,10 +243,17 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
}
}
e.configureSSHServer(server)
e.sshServer = server
if err := e.setupSSHPortRedirection(); err != nil {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
return nil
}
@@ -340,16 +336,3 @@ func (e *Engine) stopSSHServer() error {
}
return nil
}
// GetSSHServerStatus returns the SSH server status and active sessions
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
e.syncMsgMux.Lock()
sshServer := e.sshServer
e.syncMsgMux.Unlock()
if sshServer == nil {
return false, nil
}
return sshServer.GetStatus()
}

View File

@@ -7,5 +7,5 @@ import (
)
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
return stdnet.NewNet(e.config.IFaceBlackList)
}

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -24,14 +25,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
@@ -110,10 +105,6 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) RenewTun(_ int) error {
return nil
}
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
@@ -814,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -1017,7 +1008,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
@@ -1503,8 +1494,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if err != nil {
return nil, err
}
mgmtClient := mgmt.NewClient(mgmtAddr, key, false)
if err := mgmtClient.Connect(ctx); err != nil {
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
@@ -1512,8 +1503,13 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
err = mgmtClient.Register(ctx, setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1526,10 +1522,9 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
wgPort := 33100 + i
testAddr := fmt.Sprintf("100.64.0.%d/24", i+1)
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: testAddr,
WgAddr: resp.PeerConfig.Address,
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
@@ -1595,6 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -1622,16 +1618,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -20,7 +20,6 @@ import (
type wgIfaceBase interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
RenewTun(fd int) error
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address

201
client/internal/login.go Normal file
View File

@@ -0,0 +1,201 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}

View File

@@ -1,7 +1,6 @@
package ice
import (
"context"
"sync"
"time"
@@ -23,8 +22,6 @@ const (
iceFailedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
iceAgentCloseTimeout = 3 * time.Second
)
type ThreadSafeAgent struct {
@@ -35,28 +32,18 @@ type ThreadSafeAgent struct {
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
err = a.Agent.Close()
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}

View File

@@ -3,11 +3,9 @@
package ice
import (
"context"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ctx, ifaceBlacklist)
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ifaceBlacklist)
}

View File

@@ -1,11 +1,7 @@
package ice
import (
"context"
import "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
}

View File

@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) {
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}

View File

@@ -0,0 +1,136 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -55,7 +55,6 @@ type ConfigInput struct {
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
@@ -105,7 +104,6 @@ type Config struct {
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
DisableServerRoutes bool
@@ -438,12 +436,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
@@ -730,8 +722,8 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, err
}
client := mgm.NewClient(newURL.Host, key, mgmTlsEnabled)
if err := client.Connect(ctx); err != nil {
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
@@ -743,7 +735,8 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
}()
// gRPC check
if err := client.HealthCheck(ctx); err != nil {
_, err = client.GetServerPublicKey()
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -132,21 +132,3 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
return nil
}
// GetLoginHint retrieves the email from the active profile to use as login_hint.
func GetLoginHint() string {
pm := NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
log.Debugf("failed to get active profile for login hint: %v", err)
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""
}
return profileState.Email
}

View File

@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(ctx, nil)
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(ctx, nil)
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -38,7 +39,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/version"

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
"testing"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/pion/transport/v3/stdnet"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/stretchr/testify/require"
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}

View File

@@ -15,7 +15,7 @@ import (
"syscall"
"testing"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet(context.Background(), nil)
newNet, err := stdnet.NewNet()
require.NoError(t, err)
opts := iface.WGIFaceOpts{

View File

@@ -4,28 +4,17 @@
package stdnet
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"sync"
"time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
updateInterval = 30 * time.Second
dnsResolveTimeout = 30 * time.Second
)
var errNoSuitableAddress = errors.New("no suitable address found")
const updateInterval = 30 * time.Second
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
@@ -39,19 +28,12 @@ type Net struct {
// mu is shared between interfaces and lastUpdate
mu sync.Mutex
// ctx is the context for network operations that supports cancellation
ctx context.Context
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover
@@ -64,64 +46,14 @@ func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover
}
// NewNet creates a new StdNet instance.
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
func NewNet(disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
return n, n.UpdateInterfaces()
}
// resolveAddr performs DNS resolution with context support and timeout.
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return netip.AddrPort{}, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
}
if port < 0 || port > 65535 {
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
}
ipNet := "ip"
switch network {
case "tcp4", "udp4":
ipNet = "ip4"
case "tcp6", "udp6":
ipNet = "ip6"
}
if host == "" {
addr := netip.IPv4Unspecified()
if ipNet == "ip6" {
addr = netip.IPv6Unspecified()
}
return netip.AddrPortFrom(addr, uint16(port)), nil
}
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
if err != nil {
return netip.AddrPort{}, err
}
if len(addrs) == 0 {
return netip.AddrPort{}, errNoSuitableAddress
}
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
@@ -205,39 +137,3 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
}
return result
}
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "":
network = "udp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
}
return net.UDPAddrFromAddrPort(addrPort), nil
}
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "":
network = "tcp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
}
return net.TCPAddrFromAddrPort(addrPort), nil
}

File diff suppressed because one or more lines are too long

View File

@@ -1,299 +0,0 @@
package templates
import (
"html/template"
"os"
"path/filepath"
"testing"
)
func TestPKCEAuthMsgTemplate(t *testing.T) {
tests := []struct {
name string
data map[string]string
outputFile string
expectedTitle string
expectedInContent []string
notExpectedInContent []string
}{
{
name: "error_state",
data: map[string]string{
"Error": "authentication failed: invalid state",
},
outputFile: "pkce-auth-error.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication failed: invalid state",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
{
name: "success_state",
data: map[string]string{
// No error field means success
},
outputFile: "pkce-auth-success.html",
expectedTitle: "Login Successful",
expectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird. You can now close this window.",
},
notExpectedInContent: []string{
"Login Failed",
},
},
{
name: "error_state_timeout",
data: map[string]string{
"Error": "authentication timeout: request expired after 5 minutes",
},
outputFile: "pkce-auth-timeout.html",
expectedTitle: "Login Failed",
expectedInContent: []string{
"authentication timeout: request expired after 5 minutes",
"Login Failed",
},
notExpectedInContent: []string{
"Login Successful",
"Your device is now registered and logged in to NetBird",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse the template
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Create temp directory for this test
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, tt.outputFile)
// Create output file
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
// Execute the template
if err := tmpl.Execute(file, tt.data); err != nil {
file.Close()
t.Fatalf("Failed to execute template: %v", err)
}
file.Close()
t.Logf("Generated test output: %s", outputPath)
// Read the generated file
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
contentStr := string(content)
// Verify file has content
if len(contentStr) == 0 {
t.Error("Output file is empty")
}
// Verify basic HTML structure
basicElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"NetBird",
}
for _, elem := range basicElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
// Verify expected title
if !contains(contentStr, tt.expectedTitle) {
t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle)
}
// Verify expected content is present
for _, expected := range tt.expectedInContent {
if !contains(contentStr, expected) {
t.Errorf("Expected HTML to contain '%s', but it was not found", expected)
}
}
// Verify unexpected content is not present
for _, notExpected := range tt.notExpectedInContent {
if contains(contentStr, notExpected) {
t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected)
}
}
})
}
}
func TestPKCEAuthMsgTemplateValidation(t *testing.T) {
// Test that the template can be parsed without errors
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
// Test with empty data
t.Run("empty_data", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "empty-data.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
if err := tmpl.Execute(file, nil); err != nil {
t.Errorf("Template execution with nil data failed: %v", err)
}
})
// Test with error data
t.Run("with_error", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "with-error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{
"Error": "test error message",
}
if err := tmpl.Execute(file, data); err != nil {
t.Errorf("Template execution with error data failed: %v", err)
}
})
}
func TestPKCEAuthMsgTemplateContent(t *testing.T) {
// Test that the template contains expected elements
tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl)
if err != nil {
t.Fatalf("Template parsing failed: %v", err)
}
t.Run("success_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "success.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
data := map[string]string{}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for success indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Successful",
"NetBird",
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
t.Run("error_content", func(t *testing.T) {
tempDir := t.TempDir()
outputPath := filepath.Join(tempDir, "error.html")
file, err := os.Create(outputPath)
if err != nil {
t.Fatalf("Failed to create output file: %v", err)
}
defer file.Close()
errorMsg := "test error message"
data := map[string]string{
"Error": errorMsg,
}
if err := tmpl.Execute(file, data); err != nil {
t.Fatalf("Template execution failed: %v", err)
}
// Read the file and verify it contains expected content
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
// Check for error indicators
contentStr := string(content)
if len(contentStr) == 0 {
t.Error("Generated HTML is empty")
}
// Basic HTML structure checks
requiredElements := []string{
"<!DOCTYPE html>",
"<html",
"<head>",
"<body>",
"Login Failed",
errorMsg,
}
for _, elem := range requiredElements {
if !contains(contentStr, elem) {
t.Errorf("Expected HTML to contain '%s', but it was not found", elem)
}
}
})
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -208,13 +208,7 @@ func (c *Client) IsLoginRequired() bool {
ConfigPath: c.cfgFile,
})
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, _ := authClient.IsLoginRequired(ctx)
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
return needsLogin
}
@@ -234,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile,
})
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
if err != nil {
return err.Error()
}
@@ -246,17 +240,15 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return
}
jwtToken := tokenInfo.GetTokenToUse()
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
return
}
defer authClient.Close()
_ = authClient.Login(ctx, "", jwtToken)
_ = internal.Login(ctx, cfg, "", jwtToken)
c.loginComplete = true
}()

View File

@@ -3,8 +3,15 @@ package NetBirdSDK
import (
"context"
"fmt"
"time"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -64,21 +71,30 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -87,31 +103,32 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
err = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
}
func (a *Auth) Login() error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -119,10 +136,25 @@ func (a *Auth) Login() error {
return fmt.Errorf("Not authenticated")
}
err = authClient.Login(a.ctx, "", jwtToken)
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -90,7 +90,7 @@ service DaemonService {
// RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
}
@@ -168,15 +168,11 @@ message LoginRequest {
optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
optional bool enableSSHRoot = 34;
optional bool enableSSHSFTP = 35;
optional bool enableSSHLocalPortForwarding = 36;
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool enableSSHRoot = 33;
optional bool enableSSHSFTP = 34;
optional bool enableSSHLocalPortForwarding = 35;
optional bool enableSSHRemotePortForwarding = 36;
optional bool disableSSHAuth = 37;
}
message LoginResponse {
@@ -206,7 +202,7 @@ message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
optional bool waitForReady = 3;
}
message StatusResponse{
@@ -281,8 +277,6 @@ message GetConfigResponse {
bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
}
// PeerState contains the latest state of a peer
@@ -346,20 +340,6 @@ message NSGroupState {
string error = 4;
}
// SSHSessionInfo contains information about an active SSH session
message SSHSessionInfo {
string username = 1;
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
}
// SSHServerState contains the latest state of the SSH server
message SSHServerState {
bool enabled = 1;
repeated SSHSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -373,7 +353,6 @@ message FullStatus {
repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
}
// Networks
@@ -640,10 +619,9 @@ message SetConfigRequest {
optional bool enableSSHRoot = 29;
optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForwarding = 31;
optional bool enableSSHRemotePortForwarding = 32;
optional bool enableSSHLocalPortForward = 31;
optional bool enableSSHRemotePortForward = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
}
message SetConfigResponse{}
@@ -716,8 +694,6 @@ message GetPeerSSHHostKeyResponse {
// RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest {
// hint for OIDC login_hint parameter (typically email address)
optional string hint = 1;
}
// RequestJWTAuthResponse contains authentication flow information

View File

@@ -37,18 +37,13 @@ func (c *jwtCache) store(token string, maxAge time.Duration) {
c.expiresAt = time.Now().Add(maxAge)
var timer *time.Timer
timer = time.AfterFunc(maxAge, func() {
c.timer = time.AfterFunc(maxAge, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.timer != timer {
return
}
c.cleanup()
c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
})
c.timer = timer
}
func (c *jwtCache) get() (string, bool) {
@@ -75,5 +70,4 @@ func (c *jwtCache) cleanup() {
if c.enclave != nil {
c.enclave = nil
}
c.expiresAt = time.Time{}
}

View File

@@ -46,8 +46,8 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
// JWT token cache TTL for the client daemon
defaultJWTCacheTTL = 5 * time.Minute
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
@@ -278,22 +278,14 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err = authClient.Login(ctx, setupKey, jwtToken)
err := internal.Login(ctx, s.config, setupKey, jwtToken)
if err != nil {
// Check if it's an authentication error (permission denied, invalid credentials, unauthenticated)
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
log.Warnf("authentication failed: %v", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Warnf("failed login: %v", err)
status = internal.StatusNeedsLogin
} else {
log.Errorf("login failed: %v", err)
log.Errorf("failed login: %v", err)
status = internal.StatusLoginFailed
}
return status, err
@@ -389,15 +381,11 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.BlockInbound = msg.BlockInbound
config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -508,11 +496,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err
@@ -614,7 +598,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitCTX, cancel := context.WithCancel(ctx)
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
s.mutex.Lock()
@@ -1009,8 +994,8 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
}
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
mgmClient := mgm.NewClient(config.ManagementURL.Host, key, mgmTlsEnabled)
if err := mgmClient.Connect(ctx); err != nil {
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled)
if err != nil {
return fmt.Errorf("connect to management server: %w", err)
}
defer func() {
@@ -1019,7 +1004,7 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
}
}()
return mgmClient.Logout(ctx)
return mgmClient.Logout()
}
// Status returns the daemon status
@@ -1089,47 +1074,12 @@ func (s *Server) Status(
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
return &statusResponse, nil
}
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
func (s *Server) getSSHServerState() *proto.SSHServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetSSHServerStatus()
sshServerState := &proto.SSHServerState{
Enabled: enabled,
}
for _, session := range sessions {
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
Username: session.Username,
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
})
}
return sshServerState
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1182,31 +1132,35 @@ func (s *Server) GetPeerSSHHostKey(
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
func (s *Server) getJWTCacheTTL() time.Duration {
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil || config.SSHJWTCacheTTL == nil {
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
// NB_SSH_JWT_CACHE_TTL=0 disables caching
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
func getJWTCacheTTL() time.Duration {
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
if envValue == "" {
return defaultJWTCacheTTL
}
seconds, err := strconv.Atoi(envValue)
if err != nil {
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
return defaultJWTCacheTTL
}
seconds := *config.SSHJWTCacheTTL
if seconds == 0 {
log.Debug("SSH JWT cache disabled (configured to 0)")
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
msg *proto.RequestJWTAuthRequest,
_ *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
@@ -1220,7 +1174,7 @@ func (s *Server) RequestJWTAuth(
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := s.getJWTCacheTTL()
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
@@ -1232,17 +1186,8 @@ func (s *Server) RequestJWTAuth(
}
}
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
if hint == "" {
hint = profilemanager.GetLoginHint()
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
@@ -1293,7 +1238,7 @@ func (s *Server) WaitJWTToken(
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := s.getJWTCacheTTL()
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
@@ -1384,33 +1329,28 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
blockLANAccess := cfg.BlockLANAccess
enableSSHRoot := false
if cfg.EnableSSHRoot != nil {
enableSSHRoot = *cfg.EnableSSHRoot
if s.config.EnableSSHRoot != nil {
enableSSHRoot = *s.config.EnableSSHRoot
}
enableSSHSFTP := false
if cfg.EnableSSHSFTP != nil {
enableSSHSFTP = *cfg.EnableSSHSFTP
if s.config.EnableSSHSFTP != nil {
enableSSHSFTP = *s.config.EnableSSHSFTP
}
enableSSHLocalPortForwarding := false
if cfg.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
if s.config.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *s.config.EnableSSHLocalPortForwarding
}
enableSSHRemotePortForwarding := false
if cfg.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
if s.config.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if cfg.DisableSSHAuth != nil {
disableSSHAuth = *cfg.DisableSSHAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
if s.config.DisableSSHAuth != nil {
disableSSHAuth = *s.config.DisableSSHAuth
}
return &proto.GetConfigResponse{
@@ -1437,7 +1377,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}

View File

@@ -15,10 +15,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
@@ -294,6 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -314,16 +311,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -72,7 +72,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{
ProfileName: profName,
@@ -103,7 +102,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
}
_, err = s.SetConfig(ctx, req)
@@ -148,8 +146,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
require.NotNil(t, cfg.SSHJWTCacheTTL)
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
verifyAllFieldsCovered(t, req)
}
@@ -171,36 +167,35 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"EnableSSHRoot": true,
"EnableSSHSFTP": true,
"EnableSSHLocalPortForwarding": true,
"EnableSSHRemotePortForwarding": true,
"DisableSSHAuth": true,
"SshJWTCacheTTL": true,
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"EnableSSHRoot": true,
"EnableSSHSFTP": true,
"EnableSSHLocalPortForward": true,
"EnableSSHRemotePortForward": true,
"DisableSSHAuth": true,
}
val := reflect.ValueOf(req).Elem()
@@ -256,10 +251,9 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"mtu": "Mtu",
"enable-ssh-root": "EnableSSHRoot",
"enable-ssh-sftp": "EnableSSHSFTP",
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward",
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward",
"disable-ssh-auth": "DisableSSHAuth",
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).

View File

@@ -20,7 +20,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
@@ -215,7 +214,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro
}
}
// handleCommandError processes command execution errors
// handleCommandError processes command execution errors, treating exit codes as normal
func (c *Client) handleCommandError(err error) error {
if err == nil {
return nil
@@ -223,11 +222,11 @@ func (c *Client) handleCommandError(err error) error {
var e *ssh.ExitError
var em *ssh.ExitMissingError
if errors.As(err, &e) || errors.As(err, &em) {
return err
if !errors.As(err, &e) && !errors.As(err, &em) {
return fmt.Errorf("execute command: %w", err)
}
return fmt.Errorf("execute command: %w", err)
return nil
}
// setupContextCancellation sets up context cancellation for a session
@@ -282,12 +281,6 @@ type DialOptions struct {
// Dial connects to the given ssh server with specified options
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
opts.DaemonAddr = daemonAddr
hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
@@ -307,6 +300,11 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
config.Auth = append(config.Auth, authMethod)
}
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
}
@@ -343,13 +341,10 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection: %w", err)
return nil, fmt.Errorf("SSH server detection failed: %w", err)
}
if !serverType.RequiresJWT() {
@@ -370,8 +365,6 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
// requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
hint := profilemanager.GetLoginHint()
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return "", fmt.Errorf("connect to daemon: %w", err)
@@ -379,7 +372,7 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
@@ -471,7 +464,7 @@ func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicK
return nil
}
}
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}
func getKnownHostsFilesList(knownHostsFile string) []string {
@@ -515,7 +508,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
go func() {
defer func() {
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err := localListener.Close(); err != nil {
log.Debugf("local listener close error: %v", err)
}
}()
@@ -533,9 +526,6 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}()
<-ctx.Done()
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("local listener close error: %v", err)
}
return ctx.Err()
}

View File

@@ -137,10 +137,10 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
const numClients = 3
clients := make([]*Client, numClients)
currentUser := testutil.GetTestUsername(t)
for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
InsecureSkipVerify: true,
})
cancel()

View File

@@ -15,26 +15,17 @@ import (
)
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
stdinFd := int(os.Stdin.Fd())
fd := int(os.Stdout.Fd())
if !term.IsTerminal(stdinFd) {
if !term.IsTerminal(fd) {
return c.setupNonTerminalMode(ctx, session)
}
fd := int(os.Stdin.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
return c.setupNonTerminalMode(ctx, session)
}
if err := c.setupTerminal(session, fd); err != nil {
if restoreErr := term.Restore(fd, state); restoreErr != nil {
log.Debugf("restore terminal state: %v", restoreErr)
}
return err
}
c.terminalState = state
c.terminalFd = fd
@@ -64,10 +55,27 @@ func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) er
}
}()
return nil
return c.setupTerminal(session, fd)
}
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
w, h := 80, 24
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("request pty: %w", err)
}
return nil
}

View File

@@ -62,10 +62,11 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
if err := c.saveWindowsConsoleState(); err != nil {
var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) {
log.Debugf("console unavailable, not requesting PTY: %v", err)
return nil
log.Debugf("console unavailable, continuing with defaults: %v", err)
c.terminalFd = 0
} else {
return fmt.Errorf("save console state: %w", err)
}
return fmt.Errorf("save console state: %w", err)
}
if err := c.enableWindowsVirtualTerminal(); err != nil {
@@ -104,14 +105,7 @@ func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) erro
ssh.VREPRINT: 18, // Ctrl+R
}
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
log.Debugf("restore Windows console state: %v", restoreErr)
}
return fmt.Errorf("request pty: %w", err)
}
return nil
return session.RequestPty("xterm-256color", h, w, modes)
}
func (c *Client) saveWindowsConsoleState() error {

View File

@@ -68,12 +68,8 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
req := &proto.RequestJWTAuthRequest{}
if hint != "" {
req.Hint = &hint
}
authResponse, err := client.RequestJWTAuth(ctx, req)
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err)
}

View File

@@ -189,7 +189,12 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
if runtime.GOOS == "windows" {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
} else {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
}
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"

View File

@@ -3,7 +3,6 @@ package detection
import (
"bufio"
"context"
"fmt"
"net"
"strconv"
"strings"
@@ -20,8 +19,8 @@ const (
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// DefaultTimeout is the default timeout for SSH server detection
DefaultTimeout = 5 * time.Second
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
)
type ServerType string
@@ -62,20 +61,21 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
}
defer conn.Close()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetReadDeadline(deadline); err != nil {
return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
}
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
}
serverBanner = strings.TrimSpace(serverBanner)

View File

@@ -18,7 +18,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
@@ -39,7 +38,6 @@ type SSHProxy struct {
targetHost string
targetPort int
stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
}
@@ -55,22 +53,12 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint()
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
@@ -156,7 +144,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
if len(session.Command()) > 0 {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}
return
}
@@ -167,16 +154,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}
if err := serverSession.Wait(); err != nil {
log.Debugf("session wait: %v", err)
p.handleProxyExitCode(session, err)
}
}
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
}
}
}

View File

@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
execCmd, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
@@ -42,59 +42,31 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
return
}
if !hasPty {
if s.executeCommand(logger, session, execCmd, cleanup) {
logger.Debugf("%s execution completed", commandType)
}
return
}
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
if s.executeCommand(logger, session, execCmd) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
}
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
cmd, err = s.createExecutorCommand(session, localUser, hasPty)
}
if err != nil {
return nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil
return cmd, nil
}
// executeCommand executes the command and handles I/O and exit codes
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
defer cleanup()
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
s.setupProcessGroup(execCmd)
stdinPipe, err := execCmd.StdinPipe()
@@ -107,7 +79,7 @@ func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd
}
execCmd.Stdout = session
execCmd.Stderr = session.Stderr()
execCmd.Stderr = session
if execCmd.Dir != "" {
if _, err := os.Stat(execCmd.Dir); err != nil {

View File

@@ -3,13 +3,11 @@
package server
import (
"context"
"errors"
"os/exec"
"os/user"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
@@ -20,8 +18,8 @@ func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
@@ -34,19 +32,5 @@ func (s *Server) setupProcessGroup(_ *exec.Cmd) {
}
// killProcessGroup is not supported on JS/WASM
func (s *Server) killProcessGroup(*exec.Cmd) {
}
// detectSuPtySupport always returns false on JS/WASM
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
func (s *Server) killProcessGroup(_ *exec.Cmd) {
}

View File

@@ -3,14 +3,12 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strings"
"sync"
"syscall"
"time"
@@ -20,6 +18,47 @@ import (
log "github.com/sirupsen/logrus"
)
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
// TODO: handle pty flag if available
args := []string{"-l", localUser.Username, "-c", command}
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// ptyManager manages Pty file operations with thread safety
type ptyManager struct {
file *os.File
@@ -49,7 +88,7 @@ func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
if pm.closed {
return errors.New("pty is closed")
return errors.New("Pty is closed")
}
return pty.Setsize(pm.file, ws)
}
@@ -58,78 +97,6 @@ func (pm *ptyManager) File() *os.File {
return pm.file
}
// detectSuPtySupport checks if su supports the --pty flag
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "su", "--help")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("su --help failed (may not support --help): %v", err)
return false
}
supported := strings.Contains(string(output), "--pty")
log.Debugf("su --pty support detected: %v", supported)
return supported
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username, "-c", command)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// executeCommandWithPty executes a command with PTY allocation
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
@@ -144,12 +111,14 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
shell := execCmd.Path
cmd := session.Command()
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
if err != nil {
logger.Errorf("Pty start failed: %v", err)
@@ -301,29 +270,9 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
pgid := cmd.Process.Pid
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
logger.Debugf("kill process group SIGTERM: %v", err)
return
}
const gracePeriod = 500 * time.Millisecond
const checkInterval = 50 * time.Millisecond
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
timeout := time.After(gracePeriod)
for {
select {
case <-timeout:
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
logger.Debugf("kill process group SIGKILL: %v", err)
}
return
case <-ticker.C:
if err := syscall.Kill(-pgid, 0); err != nil {
return
}
logger.Debugf("kill process group SIGTERM failed: %v", err)
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
logger.Debugf("kill process group SIGKILL failed: %v", err)
}
}
}

View File

@@ -1,3 +1,5 @@
//go:build windows
package server
import (
@@ -7,6 +9,7 @@ import (
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"unsafe"
@@ -31,11 +34,6 @@ func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
}
}()
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
@@ -268,11 +266,6 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
@@ -282,6 +275,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
// Always use user switching on Windows - no direct execution
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true
}
@@ -295,8 +289,45 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
localUser := privilegeResult.User
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
var command string
rawCmd := session.RawCommand()
if rawCmd != "" {
command = rawCmd
}
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
err := executePtyCommandWithUserToken(session.Context(), session, req)
if err != nil {
logger.Errorf("Windows ConPty with user switching failed: %v", err)
var errorMsg string
if runtime.GOOS == "windows" {
errorMsg = "Windows user switching failed - NetBird must run as a Windows service or with elevated privileges for user switching\r\n"
} else {
errorMsg = "User switching failed - login command not available\r\n"
}
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return
}
logger.Debugf("Windows ConPty command execution with user switching completed")
}
type PtyExecutionRequest struct {
@@ -324,7 +355,7 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
userEnv, err := server.getUserEnvironment(req.Username, req.Domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
@@ -377,54 +408,3 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
logger.Debugf("kill process failed: %v", err)
}
}
// detectSuPtySupport always returns false on Windows as su is not available
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
return false
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Debug("ConPty execution completed")
return true
}

Some files were not shown because too many files have changed in this diff Show More