mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 15:43:47 -04:00
Compare commits
295 Commits
nmap/compa
...
dn-reverse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fb153d76 | ||
|
|
eee4d75932 | ||
|
|
62b8875f67 | ||
|
|
47a5478964 | ||
|
|
9922d6f953 | ||
|
|
f9bab22f61 | ||
|
|
3d8fdb7a89 | ||
|
|
fb10153ab8 | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
b00babb8b1 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
3bc8cbb13f | ||
|
|
bf7bdf6c4f | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
0a895ffc22 | ||
|
|
b87aa0bc15 | ||
|
|
69d4b5d821 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
b81837a364 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
2de1949018 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
b782ac6f56 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
fc88399c23 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
1b96648d4d | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
d2f9653cea | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
194a986926 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
f7732557fa | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
d488f58311 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
6fdc00ff41 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
e95cfa1a00 | ||
|
|
5333e55a81 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
81c11df103 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
f74bc48d16 | ||
|
|
4352228797 | ||
|
|
0169e4540f | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
cead3f38ee | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
b55262d4a2 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2248ff392f | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
06966da012 | ||
|
|
d4f7df271a | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
5299549eb6 | ||
|
|
7d791620a6 | ||
|
|
703ef29199 | ||
|
|
44ab454a13 | ||
|
|
11f50d6c38 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
05af39a69b | ||
|
|
1d8390b935 | ||
|
|
074df56c3d | ||
|
|
2381e216e4 | ||
|
|
ded04b7627 | ||
|
|
67211010f7 | ||
|
|
c61568ceb4 | ||
|
|
737d6061bf | ||
|
|
ee3a67d2d8 | ||
|
|
1a32e4c223 | ||
|
|
269d5d1cba | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
||||
.env
|
||||
.env.*
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
*.p12
|
||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -39,11 +39,11 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
|
||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -204,7 +204,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -261,6 +261,53 @@ jobs:
|
||||
-exec 'sudo' \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
needs: [build-cache]
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
.run
|
||||
*.iml
|
||||
dist/
|
||||
!proxy/web/dist/
|
||||
bin/
|
||||
.env
|
||||
conf.json
|
||||
|
||||
@@ -60,8 +60,8 @@
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
|
||||
@@ -3,15 +3,7 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
var needsLogin bool
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.OnLoginSuccess()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -277,18 +276,15 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check login required: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -300,23 +296,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -344,11 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -30,6 +31,14 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -68,6 +77,10 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -136,6 +149,8 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -155,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -176,13 +192,17 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
|
||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth client: %w", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
@@ -335,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
@@ -350,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return recorder.GetFullStatus(), nil
|
||||
return c.recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
|
||||
@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||
}
|
||||
|
||||
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("create chain: %w", err)
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -168,6 +168,10 @@ type Manager interface {
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -48,8 +49,10 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.Flush()
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||
return fmt.Errorf("notrack chains not initialized")
|
||||
}
|
||||
|
||||
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||
loopback := []byte{127, 0, 0, 1}
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush notrack rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawPrerouting,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush chain creation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) refreshNoTrackChains() error {
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, c := range chains {
|
||||
if c.Table.Name != tableName {
|
||||
continue
|
||||
}
|
||||
switch c.Name {
|
||||
case chainNameRawOutput:
|
||||
m.notrackOutputChain = c
|
||||
case chainNameRawPrerouting:
|
||||
m.notrackPreroutingChain = c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []string{
|
||||
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
// TODO: rollback set counter
|
||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||
func (r *router) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[string]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
newRules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.conn.DelRule(masqRule); err != nil {
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||
|
||||
err = r.refreshRulesMap()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||
|
||||
realRule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "real rule should still exist after refresh")
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "staletest",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||
})
|
||||
|
||||
// Corrupt the handle to simulate stale state
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
|
||||
// Adding the same rule again should succeed despite stale handles
|
||||
err = rtr.AddNatRule(pair)
|
||||
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||
|
||||
// Verify rules exist in kernel
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err)
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
|
||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||
if t.tombstone.Load() {
|
||||
return true
|
||||
}
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
if exists && !conn.IsSupersededBy(flags) {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.IsTombstone() {
|
||||
if !exists || conn.IsSupersededBy(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and gracefully close a connection (server-initiated close)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Server sends FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Client sends FIN-ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
|
||||
// Server sends final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Connection should be tombstoned
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn, "old connection should still be in map")
|
||||
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||
|
||||
// Now reuse the same port for a new connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// The old tombstoned entry should be replaced with a new one
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
// SYN-ACK for the new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
|
||||
// Data transfer should work
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||
require.True(t, valid, "data should be allowed on new connection")
|
||||
})
|
||||
|
||||
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and RST a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||
|
||||
// Reuse the same port
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||
})
|
||||
|
||||
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := srcIP
|
||||
serverIP := dstIP
|
||||
clientPort := srcPort
|
||||
serverPort := dstPort
|
||||
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||
|
||||
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Server-initiated close to reach Closed/tombstoned:
|
||||
// Server FIN (opposite dir) → CloseWait
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
// Client FIN-ACK (same dir as conn) → LastAck
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// New inbound connection on same ports
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
|
||||
// Complete handshake: server SYN-ACK, then client ACK
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// Late ACK should be rejected (tombstoned)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||
|
||||
// Late outbound ACK should not create a new connection (not a SYN)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Active close: client (outbound initiator) sends FIN first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Server ACKs the FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Server sends its own FIN
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||
|
||||
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||
|
||||
// SYN-ACK for new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish outbound connection and close via active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||
|
||||
// Simulate what the filter does next: TrackInbound via the normal path
|
||||
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||
newConn := tracker.connections[invertedKey]
|
||||
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
require.False(t, newConn.IsTombstone())
|
||||
})
|
||||
|
||||
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,11 +13,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
@@ -24,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -89,6 +93,7 @@ type Manager struct {
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
id: string(ruleKey),
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
return r.id == string(ruleKey)
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -570,6 +586,48 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
maps.Clear(m.outgoingRules)
|
||||
maps.Clear(m.incomingDenyRules)
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
|
||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.1.0/24"),
|
||||
netip.MustParsePrefix("100.64.2.0/24"),
|
||||
}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
// These should be the same (idempotent) like nftables/iptables implementations
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||
// different parameters get distinct IDs.
|
||||
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||
"Different rules should have different IDs")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||
// rule during a network map update does not disrupt existing traffic.
|
||||
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.True(t, passAfter,
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call blockInvalidRouted directly multiple times
|
||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule3)
|
||||
|
||||
// All should return the same rule
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||
|
||||
// Should have exactly 1 route rule
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||
|
||||
// Verify the rule blocks traffic to the WG network
|
||||
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||
// EnableRouting multiple times (as happens on each route update) does not
|
||||
// accumulate duplicate block rules in the routeRules slice.
|
||||
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount,
|
||||
"Repeated EnableRouting should not accumulate block rules")
|
||||
}
|
||||
|
||||
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||
// rule multiple times does not create duplicate entries.
|
||||
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||
}
|
||||
|
||||
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||
// after adding it multiple times works correctly.
|
||||
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||
}
|
||||
|
||||
func setupTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||
}
|
||||
|
||||
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||
// IP are stored in separate maps and don't interfere with each other.
|
||||
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
require.False(t, allowExists, "Allow rules should be empty")
|
||||
}
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,9 +18,18 @@ const (
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
logChannelSize = 1000
|
||||
defaultLogChanSize = 1000
|
||||
)
|
||||
|
||||
func getLogChannelSize() int {
|
||||
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultLogChanSize
|
||||
}
|
||||
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -69,7 +80,7 @@ type Logger struct {
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
|
||||
@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
host, portStr, err := net.SplitHostPort(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
|
||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying tun device exactly once.
|
||||
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||
// and multiple code paths can trigger Close on the same device.
|
||||
func (d *FilteredDevice) Close() error {
|
||||
var err error
|
||||
d.closeOnce.Do(func() {
|
||||
err = d.Device.Close()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
|
||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
log.Debugf("failed to close tun device: %v", cErr)
|
||||
}
|
||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
@@ -50,6 +51,7 @@ func ValidateMTU(mtu uint16) error {
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -80,6 +82,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||
func (w *WGIface) GetProxyPort() uint16 {
|
||||
return w.wgProxyFactory.GetProxyPort()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
@@ -221,6 +229,10 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
if nbnetstack.IsEnabled() {
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
if err := w.waitUntilRemoved(); err != nil {
|
||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||
if err := w.Destroy(); err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return nsTunDev, tunNet, nil
|
||||
return t.tundev, tunNet, nil
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Close() error {
|
||||
|
||||
@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start package redirection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
p.wgCurrentUsed = ep
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
@@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid address")
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -26,13 +24,10 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
proxyPort int
|
||||
mtu uint16
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
rawConnIPv4 net.PacketConn
|
||||
rawConnIPv6 net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
proxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.proxyPort = proxyPort
|
||||
|
||||
// Prepare IPv4 raw socket (required)
|
||||
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
// Prepare IPv6 raw socket (optional)
|
||||
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||
}
|
||||
if p.rawConnIPv6 != nil {
|
||||
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
Port: proxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
if p.rawConnIPv4 != nil {
|
||||
if err := p.rawConnIPv4.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.rawConnIPv6 != nil {
|
||||
if err := p.rawConnIPv6.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy listening port.
|
||||
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||
return uint16(p.proxyPort)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -216,34 +241,3 @@ generatePort:
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,12 +10,89 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
var (
|
||||
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||
|
||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||
localHostNetIPv6 = net.ParseIP("::1")
|
||||
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
)
|
||||
|
||||
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||
type PacketHeaders struct {
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH *layers.UDP
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr net.IP
|
||||
isIPv4 bool
|
||||
}
|
||||
|
||||
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
var localHostAddr net.IP
|
||||
var isIPv4 bool
|
||||
|
||||
// Check if source address is IPv4 or IPv6
|
||||
if endpoint.IP.To4() != nil {
|
||||
// IPv4 path
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPv4,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
localHostAddr = localHostNetIPv4
|
||||
isIPv4 = true
|
||||
} else {
|
||||
// IPv6 path
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPv6,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
localHostAddr = localHostNetIPv6
|
||||
isIPv4 = false
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpoint.Port),
|
||||
DstPort: layers.UDPPort(localWGListenPort),
|
||||
}
|
||||
|
||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return &PacketHeaders{
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
isIPv4: isIPv4,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
headers *PacketHeaders
|
||||
headerCurrentUsed *PacketHeaders
|
||||
rawConn net.PacketConn
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
|
||||
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create packet sender: %w", err)
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
return errIPv6ConnNotAvailable
|
||||
}
|
||||
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
return errIPv4ConnNotAvailable
|
||||
}
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
return err
|
||||
p.headers = headers
|
||||
p.rawConn = p.selectRawConn(headers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
p.headerCurrentUsed = p.headers
|
||||
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
@@ -91,10 +188,32 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||
return
|
||||
}
|
||||
|
||||
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create packet headers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
log.Error(errIPv6ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
log.Error(errIPv4ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
p.headerCurrentUsed = header
|
||||
p.rawConn = p.selectRawConn(header)
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
@@ -136,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -162,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||
defer func() {
|
||||
if err := header.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
|
||||
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||
if header.isIPv4 {
|
||||
return p.wgeBPFProxy.rawConnIPv4
|
||||
}
|
||||
return p.wgeBPFProxy.rawConnIPv6
|
||||
}
|
||||
|
||||
@@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||
if w.ebpfProxy == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ebpfProxy.GetProxyPort()
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
|
||||
@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||
func (w *USPFactory) GetProxyPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,43 +8,87 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||
}
|
||||
|
||||
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||
if isIPv4 {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
353
client/iface/wgproxy/redirect_test.go
Normal file
353
client/iface/wgproxy/redirect_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
return addr1.String() == addr2.String()
|
||||
}
|
||||
|
||||
// Compare IP and Port, ignoring zone
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||
wgPort := 51853
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||
// It verifies that:
|
||||
// 1. Initial traffic from relay connection works
|
||||
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||
// 3. Multiple packets are correctly redirected with the new source address
|
||||
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener4.Close()
|
||||
|
||||
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener6.Close()
|
||||
|
||||
// Determine which listener to use based on the NetBird address IP version
|
||||
// (this is where initial traffic will come from before RedirectAs is called)
|
||||
var wgListener *net.UDPConn
|
||||
if p2pEndpoint.IP.To4() == nil {
|
||||
wgListener = wgListener6
|
||||
} else {
|
||||
wgListener = wgListener4
|
||||
}
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0, // Random port
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
// Add TURN connection to proxy
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the proxy
|
||||
proxy.Work()
|
||||
|
||||
// Phase 1: Test initial relay traffic
|
||||
msgFromRelay := []byte("hello from relay")
|
||||
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write to relay server: %v", err)
|
||||
}
|
||||
|
||||
// Set read deadline to avoid hanging
|
||||
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _, err := wgListener4.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msgFromRelay) {
|
||||
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||
}
|
||||
|
||||
// Phase 2: Redirect to p2p endpoint
|
||||
proxy.RedirectAs(p2pEndpoint)
|
||||
|
||||
// Give the proxy a moment to process the redirect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Phase 3: Test redirected traffic
|
||||
redirectedMessages := [][]byte{
|
||||
[]byte("redirected message 1"),
|
||||
[]byte("redirected message 2"),
|
||||
[]byte("redirected message 3"),
|
||||
}
|
||||
|
||||
for i, msg := range redirectedMessages {
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
// Verify message content
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||
}
|
||||
|
||||
// Verify source address matches p2p endpoint (this is the key test)
|
||||
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||
t.Errorf("message %d: expected source address %s, got %s",
|
||||
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listener
|
||||
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener.Close()
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
// Test switching between multiple endpoints - using addresses in local subnet
|
||||
endpoints := []*net.UDPAddr{
|
||||
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||
}
|
||||
|
||||
for i, endpoint := range endpoints {
|
||||
proxy.RedirectAs(endpoint)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg := []byte("test message")
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||
}
|
||||
|
||||
if !compareUDPAddr(srcAddr, endpoint) {
|
||||
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||
i, endpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
|
||||
@@ -19,37 +19,56 @@ var (
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
localHostNetIPAddrV4 = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
localHostNetIPAddrV6 = &net.IPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr *net.IPAddr
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
// Create only the raw socket for the address family we need
|
||||
var rawSocket net.PacketConn
|
||||
var err error
|
||||
var localHostAddr *net.IPAddr
|
||||
|
||||
if srcAddr.IP.To4() != nil {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
localHostAddr = localHostNetIPAddrV4
|
||||
} else {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
localHostAddr = localHostNetIPAddrV6
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
}
|
||||
|
||||
return f, nil
|
||||
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
|
||||
// Check if source IP is IPv4 or IPv6
|
||||
if srcAddr.IP.To4() != nil {
|
||||
// IPv4
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPAddrV4.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
} else {
|
||||
// IPv6
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPAddrV6.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
}
|
||||
|
||||
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||
}
|
||||
|
||||
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: add deny and accept rules
|
||||
networkMap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap1, false)
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||
|
||||
// Second update: remove the deny rule, keep only accept
|
||||
networkMap2 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap2, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Should have 1 rule after removing deny rule")
|
||||
|
||||
// Third update: remove all rules
|
||||
networkMap3 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap3, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||
"Should have 0 rules after removing all rules")
|
||||
}
|
||||
|
||||
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: accept rule
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Second update: change to deny (same IP/port/proto, different action)
|
||||
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Changing action should result in exactly 1 rule, not 2")
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -26,12 +25,56 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return pkceFlowInfo, nil
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
return deviceFlowInfo, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -35,17 +34,67 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
@@ -228,6 +228,7 @@ type BundleGenerator struct {
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
|
||||
anonymize bool
|
||||
includeSystemInfo bool
|
||||
@@ -248,6 +249,7 @@ type GeneratorDependencies struct {
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -265,6 +267,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
@@ -408,6 +411,10 @@ func (g *BundleGenerator) addStatus() error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
if g.refreshStatus != nil {
|
||||
g.refreshStatus()
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD exact match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD subdomain match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD wildcard match",
|
||||
handlerDomain: "*.example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "two letter domain labels",
|
||||
handlerDomain: "a.b.",
|
||||
queryDomain: "a.b.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain with subdomain match",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "sub.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -38,6 +40,9 @@ const (
|
||||
type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
|
||||
mu sync.RWMutex
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager() (*systemConfigurator, error) {
|
||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
var serverAddresses []netip.Addr
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip.Unmap()
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||
ip = ip.Unmap()
|
||||
serverAddresses = append(serverAddresses, ip)
|
||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = DefaultPort
|
||||
|
||||
s.mu.Lock()
|
||||
s.origNameservers = serverAddresses
|
||||
s.mu.Unlock()
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
|
||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func TestGetOriginalNameservers(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
origNameservers: []netip.Addr{
|
||||
netip.MustParseAddr("8.8.8.8"),
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
},
|
||||
}
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
assert.Len(t, servers, 2)
|
||||
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||
}
|
||||
|
||||
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
|
||||
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||
|
||||
for _, server := range servers {
|
||||
assert.True(t, server.IsValid(), "server address should be valid")
|
||||
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||
}
|
||||
|
||||
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||
}
|
||||
|
||||
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
}
|
||||
|
||||
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
routeAll bool
|
||||
}{
|
||||
{"routeall_false", false},
|
||||
{"routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
for _, srv := range initialServers {
|
||||
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.routeAll,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialRoute bool
|
||||
}{
|
||||
{"start_with_routeall_false", false},
|
||||
{"start_with_routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.initialRoute,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
// First apply
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle RouteAll
|
||||
config.RouteAll = !tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle back
|
||||
config.RouteAll = tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
for _, srv := range servers {
|
||||
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
logger.Debug("received local resolver request with no question")
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -27,6 +29,8 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||
}
|
||||
if skipProbe {
|
||||
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
@@ -615,7 +630,7 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
s.registerFallback(config)
|
||||
}
|
||||
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||
if !ok {
|
||||
@@ -624,6 +639,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,15 +8,21 @@ import (
|
||||
|
||||
type MockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
lastResponse *dns.Msg
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
rw.lastResponse = m
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||
return rw.lastResponse
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
|
||||
@@ -71,6 +71,11 @@ type upstreamResolverBase struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
@@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
@@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if u.tryUpstreamServers(w, r, logger) {
|
||||
return
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
if !ok {
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
@@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
||||
return true
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
}
|
||||
}
|
||||
return false
|
||||
return false, failures
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
@@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
||||
return false
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
return false
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
||||
return
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
||||
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
|
||||
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||
timeoutMsg += " " + peerInfo
|
||||
reason += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warn(timeoutMsg)
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
@@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
if succeeded {
|
||||
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
} else {
|
||||
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func formatFailures(failures []upstreamFailure) string {
|
||||
parts := make([]string, 0, len(failures))
|
||||
for _, f := range failures {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
@@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
responses map[string]mockUpstreamResponse
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
@@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstream1 mockUpstreamResponse
|
||||
upstream2 mockUpstreamResponse
|
||||
expectedRcode int
|
||||
expectAnswer bool
|
||||
expectTrySecond bool
|
||||
}{
|
||||
{
|
||||
name: "success on first upstream",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "SERVFAIL from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "REFUSED from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "NXDOMAIN from first should NOT try second",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: false,
|
||||
},
|
||||
{
|
||||
name: "timeout from first should try second",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "no response from first should try second",
|
||||
upstream1: mockUpstreamResponse{msg: nil},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectAnswer: true,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams return SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "both upstreams timeout",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first SERVFAIL then timeout",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
upstream2: mockUpstreamResponse{err: timeoutErr},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first timeout then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{err: timeoutErr},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
{
|
||||
name: "first REFUSED then SERVFAIL",
|
||||
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectAnswer: false,
|
||||
expectTrySecond: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var queriedUpstreams []string
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream1.String(): tc.upstream1,
|
||||
upstream2.String(): tc.upstream2,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
trackingClient := &trackingMockClient{
|
||||
inner: mockClient,
|
||||
queriedUpstreams: &queriedUpstreams,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
|
||||
|
||||
if tc.expectAnswer {
|
||||
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
}
|
||||
|
||||
if tc.expectTrySecond {
|
||||
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
|
||||
} else {
|
||||
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
|
||||
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type trackingMockClient struct {
|
||||
inner *mockUpstreamResolverPerServer
|
||||
queriedUpstreams *[]string
|
||||
}
|
||||
|
||||
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
|
||||
return t.inner.exchange(ctx, upstream, r)
|
||||
}
|
||||
|
||||
func buildMockResponse(rcode int, answer string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.Response = true
|
||||
m.Rcode = rcode
|
||||
|
||||
if rcode == dns.RcodeSuccess && answer != "" {
|
||||
m.Answer = []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "example.com.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: net.ParseIP(answer),
|
||||
},
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
upstream := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
failures []upstreamFailure
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty slice",
|
||||
failures: []upstreamFailure{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single failure",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL",
|
||||
},
|
||||
{
|
||||
name: "multiple failures",
|
||||
failures: []upstreamFailure{
|
||||
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
|
||||
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
|
||||
},
|
||||
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := formatFailures(tc.failures)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||
if len(query.Question) == 0 {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
question := query.Question[0]
|
||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
domain := strings.ToLower(question.Name)
|
||||
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
|
||||
resp := query.SetReply(query)
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
// query doesn't match any configured domain
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||
return nil
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||
f.cache.set(domain, question.Qtype, result.IPs)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||
f.cache.set(qname, question.Qtype, result.IPs)
|
||||
|
||||
return resp
|
||||
f.writeResponse(logger, w, resp, qname, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
type udpResponseWriter struct {
|
||||
dns.ResponseWriter
|
||||
query *dns.Msg
|
||||
}
|
||||
|
||||
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
opt := u.query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
return u.ResponseWriter.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opt := query.IsEdns0()
|
||||
maxSize := dns.MinMsgSize
|
||||
if opt != nil {
|
||||
// client advertised a larger EDNS0 buffer
|
||||
maxSize = int(opt.UDPSize())
|
||||
}
|
||||
|
||||
// if our response is too big, truncate and set the TC bit
|
||||
if resp.Len() > maxSize {
|
||||
resp.Truncate(maxSize)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// NotFound: cache negative result and respond
|
||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||
resp.Rcode = verifyResult.Rcode
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
||||
// No cache or verification failed. Log with or without the server field for more context.
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
} else {
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||
|
||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
} else {
|
||||
if resp != nil {
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
}
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||
}
|
||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||
|
||||
// Verify response
|
||||
resp := mockWriter.GetLastResponse()
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.NotEmpty(t, resp.Answer)
|
||||
} else if resp != nil {
|
||||
} else {
|
||||
require.NotNil(t, resp, "Expected response")
|
||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||
"Unauthorized domain should be refused or have no answers")
|
||||
}
|
||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
// Verify response contains all IPs
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||
resp1 := w1.GetLastResponse()
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
w2 := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
resp2 := w2.GetLastResponse()
|
||||
require.NotNil(t, resp2)
|
||||
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||
require.Len(t, resp2.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
writtenResp = resp
|
||||
}
|
||||
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
resp := mockWriter.GetLastResponse()
|
||||
require.NotNil(t, resp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||
|
||||
if tt.expectNoAnswer {
|
||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||
}
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
writeCalled := false
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writeCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
@@ -505,6 +506,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set up notrack rules immediately after proxy is listening to prevent
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -539,11 +544,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.shutdownWg.Add(1)
|
||||
wgIfaceName := e.wgInterface.Name()
|
||||
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
@@ -569,9 +575,11 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("create firewall manager: %w", err)
|
||||
}
|
||||
if e.firewall == nil {
|
||||
return fmt.Errorf("create firewall manager: received nil manager")
|
||||
}
|
||||
|
||||
if err := e.initFirewall(); err != nil {
|
||||
@@ -617,6 +625,23 @@ func (e *Engine) initFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||
func (e *Engine) setupWGProxyNoTrack() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort := e.wgInterface.GetProxyPort()
|
||||
if proxyPort == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
@@ -805,6 +830,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -994,7 +1023,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||
state.FQDN = conf.GetFqdn()
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(state)
|
||||
@@ -1050,6 +1079,9 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
}
|
||||
|
||||
bundleJobParams := debug.BundleConfig{
|
||||
@@ -1641,6 +1673,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
@@ -1827,7 +1860,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1841,23 +1874,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
turns := slices.Clone(e.TURNs)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
stats, err := e.wgInterface.GetStats()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get wireguard stats: %v", err)
|
||||
e.syncMsgMux.Unlock()
|
||||
return false
|
||||
}
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
wgStats, ok := stats[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
@@ -1906,7 +1924,7 @@ func (e *Engine) triggerClientRestart() {
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
if netstack.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
if netstack.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
|
||||
@@ -107,6 +107,7 @@ type MockWGIface struct {
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetProxyPortFunc func() uint16
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
@@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||
if m.GetProxyPortFunc != nil {
|
||||
return m.GetProxyPortFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is only used on Windows and JS platforms:
|
||||
// BindListener is used on Windows, JS, and netstack platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// BindListener bypasses this by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||
// BindListener bypasses these issues by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
||||
mgmURL := config.ManagementURL
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
|
||||
@@ -2,6 +2,7 @@ package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- a.Agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
|
||||
}
|
||||
|
||||
return &ThreadSafeAgent{Agent: agent}, nil
|
||||
}
|
||||
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
// Defensive check to prevent nil pointer dereference
|
||||
// This can happen during sleep/wake transitions or memory corruption scenarios
|
||||
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
|
||||
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
|
||||
agent := a.Agent
|
||||
if agent == nil {
|
||||
log.Warnf("ICE agent is nil during close, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func GenerateICECredentials() (string, string, error) {
|
||||
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
|
||||
if err != nil {
|
||||
|
||||
@@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
return d.wgIface.FullStats()
|
||||
}
|
||||
|
||||
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
|
||||
// and updates the cached peer states. This ensures accurate handshake times and
|
||||
// transfer statistics in status reports without running full health probes.
|
||||
func (d *Status) RefreshWireGuardStats() error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if d.wgIface == nil {
|
||||
return nil // silently skip if interface not set
|
||||
}
|
||||
|
||||
stats, err := d.wgIface.FullStats()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get wireguard stats: %w", err)
|
||||
}
|
||||
|
||||
// Update each peer's WireGuard statistics
|
||||
for _, peerStats := range stats.Peers {
|
||||
peerState, ok := d.peers[peerStats.PublicKey]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
peerState.LastWireguardHandshake = peerStats.LastHandshake
|
||||
peerState.BytesRx = peerStats.RxBytes
|
||||
peerState.BytesTx = peerStats.TxBytes
|
||||
d.peers[peerStats.PublicKey] = peerState
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventQueue struct {
|
||||
maxSize int
|
||||
events []*proto.SystemEvent
|
||||
|
||||
@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
w.agentDialerCancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
if w.agent != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
|
||||
if config.AdminURL == nil {
|
||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||
log.Infof("using default Admin URL %s", DefaultAdminURL)
|
||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLog = slog.LevelInfo
|
||||
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
|
||||
)
|
||||
|
||||
func hashRosenpassKey(key []byte) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write(key)
|
||||
@@ -45,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
}
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
}
|
||||
|
||||
@@ -101,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
|
||||
|
||||
func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
Level: getLogLevel(),
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
|
||||
cfg := rp.Config{Logger: logger}
|
||||
@@ -133,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getLogLevel() slog.Level {
|
||||
level, ok := os.LookupEnv(defaultLogLevelVar)
|
||||
if !ok {
|
||||
return defaultLog
|
||||
}
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "info":
|
||||
return slog.LevelInfo
|
||||
case "warn":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
default:
|
||||
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
|
||||
return defaultLog
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) OnDisconnected(peerKey string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
}
|
||||
|
||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
var once sync.Once
|
||||
var wgIface *net.Interface
|
||||
toInterface := func() *net.Interface {
|
||||
once.Do(func() {
|
||||
wgIface = m.wgInterface.ToInterface()
|
||||
})
|
||||
return wgIface
|
||||
}
|
||||
|
||||
m.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ struct{}) error {
|
||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ package systemops
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
func formatBSDFlags(flags int) string {
|
||||
var flagStrs []string
|
||||
|
||||
if flags&syscall.RTF_UP != 0 {
|
||||
if flags&unix.RTF_UP != 0 {
|
||||
flagStrs = append(flagStrs, "U")
|
||||
}
|
||||
if flags&syscall.RTF_GATEWAY != 0 {
|
||||
if flags&unix.RTF_GATEWAY != 0 {
|
||||
flagStrs = append(flagStrs, "G")
|
||||
}
|
||||
if flags&syscall.RTF_HOST != 0 {
|
||||
if flags&unix.RTF_HOST != 0 {
|
||||
flagStrs = append(flagStrs, "H")
|
||||
}
|
||||
if flags&syscall.RTF_REJECT != 0 {
|
||||
if flags&unix.RTF_REJECT != 0 {
|
||||
flagStrs = append(flagStrs, "R")
|
||||
}
|
||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||
if flags&unix.RTF_DYNAMIC != 0 {
|
||||
flagStrs = append(flagStrs, "D")
|
||||
}
|
||||
if flags&syscall.RTF_MODIFIED != 0 {
|
||||
if flags&unix.RTF_MODIFIED != 0 {
|
||||
flagStrs = append(flagStrs, "M")
|
||||
}
|
||||
if flags&syscall.RTF_STATIC != 0 {
|
||||
if flags&unix.RTF_STATIC != 0 {
|
||||
flagStrs = append(flagStrs, "S")
|
||||
}
|
||||
if flags&syscall.RTF_LLINFO != 0 {
|
||||
if flags&unix.RTF_LLINFO != 0 {
|
||||
flagStrs = append(flagStrs, "L")
|
||||
}
|
||||
if flags&syscall.RTF_LOCAL != 0 {
|
||||
if flags&unix.RTF_LOCAL != 0 {
|
||||
flagStrs = append(flagStrs, "l")
|
||||
}
|
||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||
flagStrs = append(flagStrs, "B")
|
||||
}
|
||||
if flags&syscall.RTF_CLONING != 0 {
|
||||
if flags&unix.RTF_CLONING != 0 {
|
||||
flagStrs = append(flagStrs, "C")
|
||||
}
|
||||
if flags&syscall.RTF_WASCLONED != 0 {
|
||||
if flags&unix.RTF_WASCLONED != 0 {
|
||||
flagStrs = append(flagStrs, "W")
|
||||
}
|
||||
if flags&unix.RTF_PROTO1 != 0 {
|
||||
flagStrs = append(flagStrs, "1")
|
||||
}
|
||||
if flags&unix.RTF_PROTO2 != 0 {
|
||||
flagStrs = append(flagStrs, "2")
|
||||
}
|
||||
if flags&unix.RTF_PROTO3 != 0 {
|
||||
flagStrs = append(flagStrs, "3")
|
||||
}
|
||||
|
||||
if len(flagStrs) == 0 {
|
||||
return "-"
|
||||
|
||||
@@ -4,17 +4,18 @@ package systemops
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
if routeMessageFlags&unix.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
func formatBSDFlags(flags int) string {
|
||||
var flagStrs []string
|
||||
|
||||
if flags&syscall.RTF_UP != 0 {
|
||||
if flags&unix.RTF_UP != 0 {
|
||||
flagStrs = append(flagStrs, "U")
|
||||
}
|
||||
if flags&syscall.RTF_GATEWAY != 0 {
|
||||
if flags&unix.RTF_GATEWAY != 0 {
|
||||
flagStrs = append(flagStrs, "G")
|
||||
}
|
||||
if flags&syscall.RTF_HOST != 0 {
|
||||
if flags&unix.RTF_HOST != 0 {
|
||||
flagStrs = append(flagStrs, "H")
|
||||
}
|
||||
if flags&syscall.RTF_REJECT != 0 {
|
||||
if flags&unix.RTF_REJECT != 0 {
|
||||
flagStrs = append(flagStrs, "R")
|
||||
}
|
||||
if flags&syscall.RTF_DYNAMIC != 0 {
|
||||
if flags&unix.RTF_DYNAMIC != 0 {
|
||||
flagStrs = append(flagStrs, "D")
|
||||
}
|
||||
if flags&syscall.RTF_MODIFIED != 0 {
|
||||
if flags&unix.RTF_MODIFIED != 0 {
|
||||
flagStrs = append(flagStrs, "M")
|
||||
}
|
||||
if flags&syscall.RTF_STATIC != 0 {
|
||||
if flags&unix.RTF_STATIC != 0 {
|
||||
flagStrs = append(flagStrs, "S")
|
||||
}
|
||||
if flags&syscall.RTF_LLINFO != 0 {
|
||||
if flags&unix.RTF_LLINFO != 0 {
|
||||
flagStrs = append(flagStrs, "L")
|
||||
}
|
||||
if flags&syscall.RTF_LOCAL != 0 {
|
||||
if flags&unix.RTF_LOCAL != 0 {
|
||||
flagStrs = append(flagStrs, "l")
|
||||
}
|
||||
if flags&syscall.RTF_BLACKHOLE != 0 {
|
||||
if flags&unix.RTF_BLACKHOLE != 0 {
|
||||
flagStrs = append(flagStrs, "B")
|
||||
}
|
||||
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
|
||||
if flags&unix.RTF_PROTO1 != 0 {
|
||||
flagStrs = append(flagStrs, "1")
|
||||
}
|
||||
if flags&unix.RTF_PROTO2 != 0 {
|
||||
flagStrs = append(flagStrs, "2")
|
||||
}
|
||||
if flags&unix.RTF_PROTO3 != 0 {
|
||||
flagStrs = append(flagStrs, "3")
|
||||
}
|
||||
|
||||
if len(flagStrs) == 0 {
|
||||
return "-"
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
||||
return false, errors.New("not supported on mobile platforms")
|
||||
}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||
return false, errors.New("empty interface name")
|
||||
|
||||
@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||
return true // Assume login is required if we can't create auth client
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||
// If the check fails, assume login is required to be safe
|
||||
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
|
||||
|
||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||
go func() {
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||
return
|
||||
}
|
||||
jwtToken := tokenInfo.GetTokenToUse()
|
||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||
return
|
||||
}
|
||||
defer authClient.Close()
|
||||
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,13 +7,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
// LoginSync performs a synchronous login check without UI interaction
|
||||
// Used for background VPN connection where user should already be authenticated
|
||||
func (a *Auth) LoginSync() error {
|
||||
var needsLogin bool
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
|
||||
return fmt.Errorf("not authenticated")
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
}
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||
var needsLogin bool
|
||||
|
||||
// Create context with device name if provided
|
||||
ctx := a.ctx
|
||||
if deviceName != "" {
|
||||
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
}
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
||||
return
|
||||
})
|
||||
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(ctx, func() error {
|
||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
}
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
|
||||
const authInfoRequestTimeout = 30 * time.Second
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
// GetConfigJSON returns the current config as a JSON string.
|
||||
// This can be used by the caller to persist the config via alternative storage
|
||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
|
||||
@@ -34,6 +34,18 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}()
|
||||
}
|
||||
|
||||
// Prepare refresh callback for health probes
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: s.config,
|
||||
@@ -41,6 +53,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: s.logFile,
|
||||
CPUProfile: cpuProfileData,
|
||||
RefreshStatus: refreshStatus,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
Anonymize: req.GetAnonymize(),
|
||||
|
||||
@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
var status internal.StatusType
|
||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
if err != nil {
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
log.Errorf("failed to create auth client: %v", err)
|
||||
return internal.StatusLoginFailed, err
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
var status internal.StatusType
|
||||
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
log.Warnf("failed login: %v", err)
|
||||
status = internal.StatusNeedsLogin
|
||||
} else {
|
||||
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.oauthAuthFlow.waitCancel()
|
||||
}
|
||||
|
||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
waitCTX, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
s.mutex.Lock()
|
||||
@@ -1327,6 +1333,10 @@ func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if engine.RunHealthProbes(waitForProbeResult) {
|
||||
s.lastProbe = time.Now()
|
||||
}
|
||||
} else {
|
||||
if err := s.statusRecorder.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||
// Create a backend session to mirror the client's session request.
|
||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
<-session.Context().Done()
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
if err := session.Exit(0); err != nil {
|
||||
log.Debugf("session exit: %v", err)
|
||||
if err := serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- serverSession.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
log.Debugf("shell session: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
// handleExecution executes an SSH command or shell with privilege validation
|
||||
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,17 +15,17 @@ import (
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||
return isUtilLinux
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
// createSuCommand creates a command using su - for privilege switching.
|
||||
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
}
|
||||
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
args := []string{"-"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
args = append(args, localUser.Username)
|
||||
|
||||
command := session.RawCommand()
|
||||
if command != "" {
|
||||
args = append(args, "-c", command)
|
||||
}
|
||||
|
||||
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
return []string{shell}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
return []string{shell, "-c", cmdString}
|
||||
}
|
||||
|
||||
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
} else {
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
// Close PTY to unblock io.Copy goroutines
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close after completion: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,32 +20,32 @@ import (
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(logger, username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
logger.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
logger.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Command: session.RawCommand(),
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
|
||||
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -23,25 +26,67 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||
// propagate exit codes.
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
runTestExecutor()
|
||||
return
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// runTestExecutor emulates the netbird executor for tests.
|
||||
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||
func runTestExecutor() {
|
||||
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||
|
||||
shell := "/bin/sh"
|
||||
var command string
|
||||
for i := 3; i < len(os.Args); i++ {
|
||||
switch os.Args[i] {
|
||||
case "--shell":
|
||||
if i+1 < len(os.Args) {
|
||||
shell = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--cmd":
|
||||
if i+1 < len(os.Args) {
|
||||
command = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if command == "" {
|
||||
cmd = exec.Command(shell)
|
||||
} else {
|
||||
cmd = exec.Command(shell, "-c", command)
|
||||
}
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||
// This ensures our implementation matches OpenSSH behavior for:
|
||||
// - ssh host command (no PTY - default when no TTY)
|
||||
// - ssh -T host command (explicit no PTY)
|
||||
// - ssh -t host command (force PTY)
|
||||
// - ssh -T host (no PTY shell - our implementation)
|
||||
func TestSSHPtyModes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
baseArgs := []string{
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "BatchMode=yes",
|
||||
}
|
||||
|
||||
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "no_pty_default")
|
||||
})
|
||||
|
||||
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "explicit_no_pty")
|
||||
})
|
||||
|
||||
t.Run("command_force_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "force_pty")
|
||||
})
|
||||
|
||||
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||
assert.NoError(t, err, "write echo command")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
assert.NoError(t, err, "write exit command")
|
||||
}()
|
||||
|
||||
output, _ := io.ReadAll(stdout)
|
||||
err = cmd.Wait()
|
||||
|
||||
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "Command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "PTY command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||
})
|
||||
|
||||
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||
})
|
||||
|
||||
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||
assert.Contains(t, string(output), "stdout_msg")
|
||||
assert.Contains(t, string(output), "stderr_msg")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct{}
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
os.Exit(ExitCodeSuccess)
|
||||
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||
} else {
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
}
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if config.Command == "" {
|
||||
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||
} else {
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
}
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
|
||||
@@ -28,22 +28,45 @@ const (
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct{}
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -56,7 +79,6 @@ const (
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
return prepareDomainS4ULogon(logger, username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(username)
|
||||
return prepareLocalS4ULogon(logger, username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
logger.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Debugf("close impersonation token: %v", err)
|
||||
pd.log().Debugf("close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
||||
func (s *Server) isPortForwardingEnabled() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
|
||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
sessions = append(sessions, info)
|
||||
}
|
||||
|
||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
||||
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||
for key, connState := range s.connections {
|
||||
remoteAddr := string(key)
|
||||
if reportedAddrs[remoteAddr] {
|
||||
|
||||
@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
||||
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
expectAllowed bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "session_allowed_with_local_forwarding",
|
||||
name: "shell_with_local_forwarding_enabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "session_allowed_with_remote_forwarding",
|
||||
name: "shell_with_remote_forwarding_enabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "session_allowed_with_both",
|
||||
name: "shell_with_both_forwarding_enabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
||||
},
|
||||
{
|
||||
name: "session_denied_without_forwarding",
|
||||
name: "shell_with_forwarding_disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: false,
|
||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
// Connect to the server without requesting PTY or command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
// Execute a command without PTY - this simulates ssh -T with no command
|
||||
// The server should either allow it (port forwarding enabled) or reject it
|
||||
output, err := client.ExecuteCommand(ctx, "")
|
||||
if tt.expectAllowed {
|
||||
// When allowed, the session stays open until cancelled
|
||||
// ExecuteCommand with empty command should return without error
|
||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
||||
assert.NotContains(t, output, "port forwarding is disabled",
|
||||
"Output should not contain port forwarding disabled message")
|
||||
} else if err != nil {
|
||||
// When denied, we expect an error message about port forwarding being disabled
|
||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
||||
"Should get port forwarding disabled message")
|
||||
}
|
||||
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||
// Should always succeed regardless of port forwarding settings
|
||||
_, err = client.ExecuteCommand(ctx, "")
|
||||
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
assert.Equal(t, "-c", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
args = server.getShellCommandArgs("/bin/sh", "")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Len(t, args, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
// ssh -T (or ssh -N) - no PTY, no command
|
||||
s.handleNonInteractiveSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
||||
s.updateSessionType(session, cmdNonInteractive)
|
||||
|
||||
if !s.isPortForwardingEnabled() {
|
||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
<-session.Context().Done()
|
||||
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, state := range s.sessions {
|
||||
if state.session == session {
|
||||
state.sessionType = sessionType
|
||||
return
|
||||
}
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||
} else {
|
||||
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
// handlePtyLogin is not supported on JS/WASM
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
|
||||
@@ -8,19 +8,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartTestServer starts the SSH server and returns the address it's listening on.
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Use port 0 to let the OS assign a free port
|
||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the actual listening address from the server
|
||||
actualAddr := server.Addr()
|
||||
if actualAddr == nil {
|
||||
errChan <- fmt.Errorf("server started but no listener address available")
|
||||
|
||||
@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||
// Returns the command and a cleanup function (no-op on Unix).
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
|
||||
@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
return s.createUserSwitchCommand(logger, session, localUser)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper()
|
||||
dropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
||||
cleanup := func() {
|
||||
if token != 0 {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
logger.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user