mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 23:14:11 -04:00
Compare commits
91 Commits
wasm-debug
...
move-licen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24b66fb406 | ||
|
|
9378b6b0a3 | ||
|
|
3779a3385f | ||
|
|
b5d75ad9c4 | ||
|
|
8db91abfdf | ||
|
|
6f817cad6d | ||
|
|
e3bb8c1b7b | ||
|
|
107066fa3d | ||
|
|
a7a85d4dc8 | ||
|
|
576b4a779c | ||
|
|
e6854dfd99 | ||
|
|
6f14134988 | ||
|
|
4fd64379da | ||
|
|
c20202a6c3 | ||
|
|
4386a21956 | ||
|
|
5882daf5d9 | ||
|
|
11d71e6e22 | ||
|
|
4dadcfd9bd | ||
|
|
34b55c600e | ||
|
|
316c0afa9a | ||
|
|
cf97799db8 | ||
|
|
4d297205c3 | ||
|
|
559f6aeeaf | ||
|
|
7216c201da | ||
|
|
4d89d0f115 | ||
|
|
610c880ec9 | ||
|
|
19adcb5f63 | ||
|
|
f3d31698da | ||
|
|
d9efe4e944 | ||
|
|
7e0bbaaa3c | ||
|
|
b3c7b3c7b2 | ||
|
|
66483ab48d | ||
|
|
5272fc2b18 | ||
|
|
4c53372815 | ||
|
|
79d28b71ee | ||
|
|
77a352763d | ||
|
|
cdd5c6c005 | ||
|
|
b1a9242c98 | ||
|
|
b43ef4f17b | ||
|
|
758a97c352 | ||
|
|
d93b7c2f38 | ||
|
|
fa893aa0a4 | ||
|
|
ac7120871b | ||
|
|
9a7daa132e | ||
|
|
cdded8c22e | ||
|
|
e4e0b8fff9 | ||
|
|
a4b067553d | ||
|
|
088956645f | ||
|
|
aa30b7afe8 | ||
|
|
f1bb4d2ac3 | ||
|
|
982841e25b | ||
|
|
a476b8d12f | ||
|
|
a21f924b26 | ||
|
|
9e51d2e8fb | ||
|
|
3e490d974c | ||
|
|
04bb314426 | ||
|
|
6e15882c11 | ||
|
|
76f9e11b29 | ||
|
|
612de2c784 | ||
|
|
1fdde66c31 | ||
|
|
5970591d24 | ||
|
|
0d5408baec | ||
|
|
96084e3a02 | ||
|
|
4bbca28eb6 | ||
|
|
279b77dee0 | ||
|
|
9d1554f9f7 | ||
|
|
f56075ca15 | ||
|
|
6ed846ae29 | ||
|
|
520f2cfdb4 | ||
|
|
0f79a8942d | ||
|
|
5299e9fda3 | ||
|
|
11bdf5b3a5 | ||
|
|
5fc95d4a0c | ||
|
|
c7884039b8 | ||
|
|
26fc32f1be | ||
|
|
a79cb1c11b | ||
|
|
306d75fe1a | ||
|
|
9468e69c8c | ||
|
|
f51ce7cee5 | ||
|
|
d47c6b624e | ||
|
|
471f90e8db | ||
|
|
1a3b04d2fe | ||
|
|
51b9e93eb9 | ||
|
|
2952669e97 | ||
|
|
7cd44a9a3c | ||
|
|
8684981b57 | ||
|
|
8e94d85d14 | ||
|
|
631b77dc3c | ||
|
|
50ac3d437e | ||
|
|
49bbd90557 | ||
|
|
bb74e903cd |
@@ -1,15 +1,15 @@
|
||||
FROM golang:1.25-bookworm
|
||||
FROM golang:1.23-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
gettext-base=0.21-12 \
|
||||
iptables=1.8.9-2 \
|
||||
libgl1-mesa-dev=22.3.6-1+deb12u1 \
|
||||
xorg-dev=1:7.7+23 \
|
||||
libayatana-appindicator3-dev=0.5.92-1 \
|
||||
gettext-base=0.21-4 \
|
||||
iptables=1.8.7-1 \
|
||||
libgl1-mesa-dev=20.3.5-1 \
|
||||
xorg-dev=1:7.7+22 \
|
||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
&& go install -v golang.org/x/tools/gopls@v0.18.1
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Running pre-push hook..."
|
||||
if ! make lint; then
|
||||
echo ""
|
||||
echo "Hint: To push without verification, run:"
|
||||
echo " git push --no-verify"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks passed!"
|
||||
116
.github/workflows/check-license-dependencies.yml
vendored
116
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,108 +3,40 @@ name: Check License Dependencies
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
name: Check Internal AGPL Dependencies
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
|
||||
check-external-licenses:
|
||||
name: Check External GPL/AGPL Licenses
|
||||
check-dependencies:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All license dependencies are clean"
|
||||
fi
|
||||
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,14 +15,13 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
|
||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
@@ -106,15 +106,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -151,15 +151,15 @@ jobs:
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.25-alpine \
|
||||
golang:1.23-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
@@ -220,15 +220,15 @@ jobs:
|
||||
raceFlag: "-race"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -259,7 +259,7 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
-timeout 10m ./relay/... ./shared/relay/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
@@ -270,15 +270,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -321,15 +321,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -408,16 +408,15 @@ jobs:
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -498,15 +497,15 @@ jobs:
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -562,15 +561,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/setup-go@v5
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
|
||||
9
.github/workflows/golangci-lint.yml
vendored
9
.github/workflows/golangci-lint.yml
vendored
@@ -46,16 +46,13 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m --out-format colored-line-number
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
@@ -56,9 +56,9 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
|
||||
104
.github/workflows/release.yml
vendored
104
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -19,102 +19,8 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release_freebsd_port:
|
||||
name: "FreeBSD Port / Build & Test"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
run: |
|
||||
if ls netbird-*.diff 1> /dev/null 2>&1; then
|
||||
echo "diff_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "diff_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No diff file generated (port may already be up to date)"
|
||||
fi
|
||||
|
||||
- name: Extract version
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
|
||||
# Clone ports tree (shallow, only what we need)
|
||||
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
|
||||
cd /usr/ports
|
||||
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin
|
||||
|
||||
# Find the diff file
|
||||
echo "Finding diff file..."
|
||||
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
|
||||
echo "Found: $DIFF_FILE"
|
||||
|
||||
if [[ -z "$DIFF_FILE" ]]; then
|
||||
echo "ERROR: Could not find diff file"
|
||||
find ~ -name "*.diff" -type f 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
|
||||
cd /usr/ports
|
||||
patch -p1 -V none < "$DIFF_FILE"
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
./netbird-*-issue.txt
|
||||
./netbird-*.diff
|
||||
retention-days: 30
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest-m
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -134,7 +40,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -230,7 +136,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -294,7 +200,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
@@ -67,13 +67,10 @@ jobs:
|
||||
- name: Install curl
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -83,6 +80,9 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup MySQL privileges
|
||||
if: matrix.store == 'mysql'
|
||||
run: |
|
||||
@@ -243,7 +243,6 @@ jobs:
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose logs
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
|
||||
|
||||
|
||||
21
.github/workflows/wasm-build-validation.yml
vendored
21
.github/workflows/wasm-build-validation.yml
vendored
@@ -14,27 +14,26 @@ jobs:
|
||||
js_lint:
|
||||
name: "JS / Lint"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOOS: js
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
working-directory: ./client
|
||||
skip-pkg-cache: true
|
||||
skip-build-cache: true
|
||||
- name: Run golangci-lint for WASM
|
||||
run: |
|
||||
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||
continue-on-error: true
|
||||
|
||||
js_build:
|
||||
@@ -46,7 +45,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
go-version: "1.23.x"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
env:
|
||||
@@ -61,8 +60,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
if [ ${SIZE} -gt 52428800 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -31,4 +31,3 @@ infrastructure_files/setup-*.env
|
||||
.DS_Store
|
||||
vendor/
|
||||
/netbird
|
||||
client/netbird-electron/
|
||||
|
||||
257
.golangci.yaml
257
.golangci.yaml
@@ -1,124 +1,139 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- bodyclose
|
||||
- dupword
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- forbidigo
|
||||
- gocritic
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- mirror
|
||||
- misspell
|
||||
- nilerr
|
||||
- nilnil
|
||||
- predeclared
|
||||
- revive
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- wastedassign
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: false
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
gosec:
|
||||
includes:
|
||||
- G101
|
||||
- G103
|
||||
- G104
|
||||
- G106
|
||||
- G108
|
||||
- G109
|
||||
- G110
|
||||
- G111
|
||||
- G201
|
||||
- G202
|
||||
- G203
|
||||
- G301
|
||||
- G302
|
||||
- G303
|
||||
- G304
|
||||
- G305
|
||||
- G306
|
||||
- G307
|
||||
- G403
|
||||
- G502
|
||||
- G503
|
||||
- G504
|
||||
- G601
|
||||
- G602
|
||||
govet:
|
||||
enable:
|
||||
- nilness
|
||||
enable-all: false
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
arguments:
|
||||
- checkPrivateReceivers
|
||||
- sayRepetitiveInsteadOfStutters
|
||||
severity: warning
|
||||
disabled: false
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 6m
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: false
|
||||
|
||||
gosec:
|
||||
includes:
|
||||
- G101 # Look for hard coded credentials
|
||||
#- G102 # Bind to all interfaces
|
||||
- G103 # Audit the use of unsafe block
|
||||
- G104 # Audit errors not checked
|
||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
||||
#- G107 # Url provided to HTTP request as taint input
|
||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
||||
- G110 # Potential DoS vulnerability via decompression bomb
|
||||
- G111 # Potential directory traversal
|
||||
#- G112 # Potential slowloris attack
|
||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
||||
- G201 # SQL query construction using format string
|
||||
- G202 # SQL query construction using string concatenation
|
||||
- G203 # Use of unescaped data in HTML templates
|
||||
#- G204 # Audit use of command execution
|
||||
- G301 # Poor file permissions used when creating a directory
|
||||
- G302 # Poor file permissions used with chmod
|
||||
- G303 # Creating tempfile using a predictable path
|
||||
- G304 # File path provided as taint input
|
||||
- G305 # File traversal when extracting zip/tar archive
|
||||
- G306 # Poor file permissions used when writing to a new file
|
||||
- G307 # Poor file permissions used when creating a file with os.Create
|
||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
||||
#- G402 # Look for bad TLS connection settings
|
||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
||||
#- G404 # Insecure random number source (rand)
|
||||
#- G501 # Import blocklist: crypto/md5
|
||||
- G502 # Import blocklist: crypto/des
|
||||
- G503 # Import blocklist: crypto/rc4
|
||||
- G504 # Import blocklist: net/http/cgi
|
||||
#- G505 # Import blocklist: crypto/sha1
|
||||
- G601 # Implicit memory aliasing of items from a range statement
|
||||
- G602 # Slice access out of bounds
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: false
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: management/cmd/root\.go
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: signal/cmd/root\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: sharedsock/filter\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
path: test\.go
|
||||
- linters:
|
||||
- nilnil
|
||||
path: mock\.go
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.DialContext is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.WithBlock is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1001"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1008"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||
- gosimple # specializes in simplifying a code
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # detects when assignments to existing variables are not used
|
||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # checks for unused constants, variables, functions and types
|
||||
## disable by default but the have interesting results so lets add them
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- dupword # dupword checks for duplicate words in the source code
|
||||
- durationcheck # durationcheck checks for two durations multiplied together
|
||||
- forbidigo # forbidigo forbids identifiers
|
||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||
- gosec # inspects source code for security problems
|
||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
||||
- misspell # misspess finds commonly misspelled English words in comments
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 5
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
exclude-rules:
|
||||
# allow fmt
|
||||
- path: management/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: signal/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: sharedsock/filter\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: client/firewall/iptables/rule\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: test\.go
|
||||
linters:
|
||||
- mirror
|
||||
- gosec
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
# Exclude specific deprecation warnings for grpc methods
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.DialContext is deprecated"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.WithBlock is deprecated"
|
||||
|
||||
@@ -713,10 +713,8 @@ checksum:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
@@ -136,14 +136,6 @@ checked out and set up:
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
6. Configure Git hooks for automatic linting:
|
||||
|
||||
```bash
|
||||
make setup-hooks
|
||||
```
|
||||
|
||||
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||
|
||||
### Dev Container Support
|
||||
|
||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||
|
||||
27
Makefile
27
Makefile
@@ -1,27 +0,0 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
$(GOLANGCI_LINT):
|
||||
@echo "Installing golangci-lint..."
|
||||
@mkdir -p ./bin
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# Lint only changed files (fast, for pre-push)
|
||||
lint: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on changed files..."
|
||||
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||
|
||||
# Lint entire codebase (slow, matches CI)
|
||||
lint-all: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on all files..."
|
||||
@$(GOLANGCI_LINT) run --timeout=12m
|
||||
|
||||
# Just install the linter
|
||||
lint-install: $(GOLANGCI_LINT)
|
||||
|
||||
# Setup git hooks for all developers
|
||||
setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
11
README.md
11
README.md
@@ -38,11 +38,6 @@
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
@@ -90,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
@@ -103,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
@@ -118,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
@@ -4,13 +4,10 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -19,13 +16,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -59,6 +53,7 @@ func init() {
|
||||
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter device.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
@@ -72,11 +67,12 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
execWorkaround(androidSDKVersion)
|
||||
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
uiVersion: uiVersion,
|
||||
tunAdapter: tunAdapter,
|
||||
@@ -88,16 +84,10 @@ func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAd
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
|
||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgFile,
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -117,29 +107,23 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.login(urlOpener, isAndroidTV)
|
||||
err = auth.login(urlOpener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
|
||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgFile,
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -157,8 +141,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -172,19 +156,6 @@ func (c *Client) Stop() {
|
||||
c.ctxCancel()
|
||||
}
|
||||
|
||||
func (c *Client) RenewTun(fd int) error {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("engine not running")
|
||||
}
|
||||
|
||||
e := c.connectClient.Engine()
|
||||
if e == nil {
|
||||
return fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
return e.RenewTun(fd)
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
@@ -206,7 +177,6 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
@@ -231,43 +201,31 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
log.Error("could not get route selector")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
@@ -295,69 +253,6 @@ func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
func (c *Client) toggleRoute(command routeCommand) error {
|
||||
return command.toggleRoute()
|
||||
}
|
||||
|
||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||
client := c.connectClient
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := client.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine is not running")
|
||||
}
|
||||
|
||||
manager := engine.GetRouteManager()
|
||||
if manager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||
domains := NetworkDomains{}
|
||||
|
||||
for _, d := range route.Domains {
|
||||
networkDomain := NetworkDomain{
|
||||
Address: d.SafeString(),
|
||||
}
|
||||
|
||||
if info, exists := resolvedDomains[d]; exists {
|
||||
for _, prefix := range info.Prefixes {
|
||||
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
domains.Add(&networkDomain)
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
|
||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(url string, userCode string)
|
||||
Open(string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
go func() {
|
||||
err := a.login(urlOpener, isAndroidTV)
|
||||
err := a.login(urlOpener)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
func (a *Auth) login(urlOpener URLOpener) error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ResolvedIPs struct {
|
||||
resolvedIPs []string
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(r.resolvedIPs) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return r.resolvedIPs[i], nil
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Size() int {
|
||||
return len(r.resolvedIPs)
|
||||
}
|
||||
|
||||
type NetworkDomain struct {
|
||||
Address string
|
||||
resolvedIPs ResolvedIPs
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||
d.resolvedIPs.Add(resolvedIP)
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||
return &d.resolvedIPs
|
||||
}
|
||||
|
||||
type NetworkDomains struct {
|
||||
domains []*NetworkDomain
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||
n.domains = append(n.domains, domain)
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||
if i < 0 || i >= len(n.domains) {
|
||||
return nil, fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return n.domains[i], nil
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Size() int {
|
||||
return len(n.domains)
|
||||
}
|
||||
@@ -3,16 +3,10 @@
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
IsSelected bool
|
||||
Domains NetworkDomains
|
||||
}
|
||||
|
||||
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||
return &n.Domains
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
@@ -7,11 +5,6 @@ type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
Routes PeerRoutes
|
||||
}
|
||||
|
||||
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||
return &p.Routes
|
||||
}
|
||||
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PeerRoutes struct {
|
||||
routes []string
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(p.routes) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return p.routes[i], nil
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Size() int {
|
||||
return len(p.routes)
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||
// at their default locations due to android OS restrictions.
|
||||
type PlatformFiles interface {
|
||||
ConfigurationFilePath() string
|
||||
StateFilePath() string
|
||||
}
|
||||
@@ -1,257 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
// Android-specific config filename (different from desktop default.json)
|
||||
defaultConfigFilename = "netbird.cfg"
|
||||
// Subdirectory for non-default profiles (must match Java Preferences.java)
|
||||
profilesSubdir = "profiles"
|
||||
// Android uses a single user context per app (non-empty username required by ServiceManager)
|
||||
androidUsername = "android"
|
||||
)
|
||||
|
||||
// Profile represents a profile for gomobile
|
||||
type Profile struct {
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
// ProfileArray wraps profiles for gomobile compatibility
|
||||
type ProfileArray struct {
|
||||
items []*Profile
|
||||
}
|
||||
|
||||
// Length returns the number of profiles
|
||||
func (p *ProfileArray) Length() int {
|
||||
return len(p.items)
|
||||
}
|
||||
|
||||
// Get returns the profile at index i
|
||||
func (p *ProfileArray) Get(i int) *Profile {
|
||||
if i < 0 || i >= len(p.items) {
|
||||
return nil
|
||||
}
|
||||
return p.items[i]
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
/data/data/io.netbird.client/files/ ← configDir parameter
|
||||
├── netbird.cfg ← Default profile config
|
||||
├── state.json ← Default profile state
|
||||
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||
└── profiles/ ← Subdirectory for non-default profiles
|
||||
├── work.json ← Work profile config
|
||||
├── work.state.json ← Work profile state
|
||||
├── personal.json ← Personal profile config
|
||||
└── personal.state.json ← Personal profile state
|
||||
*/
|
||||
|
||||
// ProfileManager manages profiles for Android
|
||||
// It wraps the internal profilemanager to provide Android-specific behavior
|
||||
type ProfileManager struct {
|
||||
configDir string
|
||||
serviceMgr *profilemanager.ServiceManager
|
||||
}
|
||||
|
||||
// NewProfileManager creates a new profile manager for Android
|
||||
func NewProfileManager(configDir string) *ProfileManager {
|
||||
// Set the default config path for Android (stored in root configDir, not profiles/)
|
||||
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
|
||||
|
||||
// Set global paths for Android
|
||||
profilemanager.DefaultConfigPathDir = configDir
|
||||
profilemanager.DefaultConfigPath = defaultConfigPath
|
||||
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
|
||||
|
||||
// Create ServiceManager with profiles/ subdirectory
|
||||
// This avoids modifying the global ConfigDirOverride for profile listing
|
||||
profilesDir := filepath.Join(configDir, profilesSubdir)
|
||||
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
|
||||
|
||||
return &ProfileManager{
|
||||
configDir: configDir,
|
||||
serviceMgr: serviceMgr,
|
||||
}
|
||||
}
|
||||
|
||||
// ListProfiles returns all available profiles
|
||||
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
|
||||
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||
}
|
||||
|
||||
// Convert internal profiles to Android Profile type
|
||||
var profiles []*Profile
|
||||
for _, p := range internalProfiles {
|
||||
profiles = append(profiles, &Profile{
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
})
|
||||
}
|
||||
|
||||
return &ProfileArray{items: profiles}, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the currently active profile name
|
||||
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return activeState.Name, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches to a different profile
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: androidUsername,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("switched to profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile
|
||||
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||
// Use ServiceManager (creates profile in profiles/ directory)
|
||||
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to add profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("created new profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutProfile logs out from a profile (clears authentication)
|
||||
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
configPath, err := pm.getProfileConfigPath(profileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if profile exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||
}
|
||||
|
||||
// Read current config using internal profilemanager
|
||||
config, err := profilemanager.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read profile config: %w", err)
|
||||
}
|
||||
|
||||
// Clear authentication by removing private key and SSH key
|
||||
config.PrivateKey = ""
|
||||
config.SSHKey = ""
|
||||
|
||||
// Save config using internal profilemanager
|
||||
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("logged out from profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes a profile
|
||||
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||
// Use ServiceManager (removes profile from profiles/ directory)
|
||||
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("removed profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProfileConfigPath returns the config file path for a profile
|
||||
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
// Android uses netbird.cfg for default profile instead of default.json
|
||||
// Default profile is stored in root configDir, not in profiles/
|
||||
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||
}
|
||||
|
||||
// Non-default profiles are stored in profiles subdirectory
|
||||
// This matches the Java Preferences.java expectation
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath returns the config file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||
return pm.getProfileConfigPath(profileName)
|
||||
}
|
||||
|
||||
// GetStateFilePath returns the state file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
return filepath.Join(pm.configDir, "state.json"), nil
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||
}
|
||||
|
||||
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
|
||||
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||
activeProfile, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetConfigPath(activeProfile)
|
||||
}
|
||||
|
||||
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||
activeProfile, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetStateFilePath(activeProfile)
|
||||
}
|
||||
|
||||
// sanitizeProfileName removes invalid characters from profile name
|
||||
func sanitizeProfileName(name string) string {
|
||||
// Keep only alphanumeric, underscore, and hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
operationName string,
|
||||
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
manager.TriggerSelection(manager.GetClientRoutes())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type routeCommand interface {
|
||||
toggleRoute() error
|
||||
}
|
||||
|
||||
type selectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (s selectRouteCommand) toggleRoute() error {
|
||||
routeSelector := s.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||
}
|
||||
|
||||
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||
}
|
||||
|
||||
type deselectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (d deselectRouteCommand) toggleRoute() error {
|
||||
routeSelector := d.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||
}
|
||||
@@ -136,7 +136,6 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
//nolint
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
@@ -314,8 +313,9 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||
statusOutputString = overview.FullDetailSummary()
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
@@ -81,7 +83,6 @@ var loginCmd = &cobra.Command{
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -105,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -207,7 +201,6 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -248,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -276,7 +269,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -293,7 +286,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -322,17 +315,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -373,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
|
||||
@@ -85,9 +85,6 @@ var (
|
||||
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
}
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
@@ -390,7 +387,6 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
parentDir := filepath.Dir(networkdConfDir)
|
||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
bundlePubKeysRootPrivKeyFile string
|
||||
bundlePubKeysPubKeyFiles []string
|
||||
bundlePubKeysFile string
|
||||
|
||||
createArtifactKeyRootPrivKeyFile string
|
||||
createArtifactKeyPrivKeyFile string
|
||||
createArtifactKeyPubKeyFile string
|
||||
createArtifactKeyExpiration time.Duration
|
||||
)
|
||||
|
||||
var createArtifactKeyCmd = &cobra.Command{
|
||||
Use: "create-artifact-key",
|
||||
Short: "Create a new artifact signing key",
|
||||
Long: `Generate a new artifact signing key pair signed by the root private key.
|
||||
The artifact key will be used to sign software artifacts/updates.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if createArtifactKeyExpiration <= 0 {
|
||||
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||
}
|
||||
|
||||
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
|
||||
return fmt.Errorf("failed to create artifact key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var bundlePubKeysCmd = &cobra.Command{
|
||||
Use: "bundle-pub-keys",
|
||||
Short: "Bundle multiple artifact public keys into a signed package",
|
||||
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
|
||||
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(bundlePubKeysPubKeyFiles) == 0 {
|
||||
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
|
||||
}
|
||||
|
||||
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
|
||||
return fmt.Errorf("failed to bundle public keys: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createArtifactKeyCmd)
|
||||
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
|
||||
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
|
||||
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||
panic(fmt.Errorf("mark expiration as required: %w", err))
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(bundlePubKeysCmd)
|
||||
|
||||
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
|
||||
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
|
||||
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
|
||||
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||
}
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||
}
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
|
||||
cmd.Println("Creating new artifact signing key...")
|
||||
|
||||
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root private key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate artifact key: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
|
||||
}
|
||||
|
||||
signatureFile := artifactPubKeyFile + ".sig"
|
||||
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Artifact key created successfully.\n")
|
||||
cmd.Printf("%s\n", artifactKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
|
||||
cmd.Println("📦 Bundling public keys into signed package...")
|
||||
|
||||
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root private key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
|
||||
for _, pubFile := range artifactPubKeyFiles {
|
||||
pubPem, err := os.ReadFile(pubFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read public key file: %w", err)
|
||||
}
|
||||
|
||||
pk, err := reposign.ParseArtifactPubKey(pubPem)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact key: %w", err)
|
||||
}
|
||||
publicKeys = append(publicKeys, pk)
|
||||
}
|
||||
|
||||
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bundle artifact keys: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
|
||||
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
|
||||
}
|
||||
|
||||
signatureFile := bundlePubKeysFile + ".sig"
|
||||
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
|
||||
return nil
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
|
||||
)
|
||||
|
||||
var (
|
||||
signArtifactPrivKeyFile string
|
||||
signArtifactArtifactFile string
|
||||
|
||||
verifyArtifactPubKeyFile string
|
||||
verifyArtifactFile string
|
||||
verifyArtifactSignatureFile string
|
||||
|
||||
verifyArtifactKeyPubKeyFile string
|
||||
verifyArtifactKeyRootPubKeyFile string
|
||||
verifyArtifactKeySignatureFile string
|
||||
verifyArtifactKeyRevocationFile string
|
||||
)
|
||||
|
||||
var signArtifactCmd = &cobra.Command{
|
||||
Use: "sign-artifact",
|
||||
Short: "Sign an artifact using an artifact private key",
|
||||
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
|
||||
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
|
||||
return fmt.Errorf("failed to sign artifact: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var verifyArtifactCmd = &cobra.Command{
|
||||
Use: "verify-artifact",
|
||||
Short: "Verify an artifact signature using an artifact public key",
|
||||
Long: `Verify a software artifact signature using the artifact's public key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
|
||||
return fmt.Errorf("failed to verify artifact: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var verifyArtifactKeyCmd = &cobra.Command{
|
||||
Use: "verify-artifact-key",
|
||||
Short: "Verify an artifact public key was signed by a root key",
|
||||
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
|
||||
This validates the chain of trust from the root key to the artifact key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
|
||||
return fmt.Errorf("failed to verify artifact key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(signArtifactCmd)
|
||||
rootCmd.AddCommand(verifyArtifactCmd)
|
||||
rootCmd.AddCommand(verifyArtifactKeyCmd)
|
||||
|
||||
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
|
||||
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
|
||||
|
||||
// artifact-file is required, but artifact-key-file can come from env var
|
||||
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||
}
|
||||
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
|
||||
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||
}
|
||||
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
|
||||
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
|
||||
cmd.Println("🖋️ Signing artifact...")
|
||||
|
||||
// Load private key from env var or file
|
||||
var privKeyPEM []byte
|
||||
var err error
|
||||
|
||||
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
|
||||
// Use key from environment variable
|
||||
privKeyPEM = []byte(envKey)
|
||||
} else if privKeyFile != "" {
|
||||
// Fall back to file
|
||||
privKeyPEM, err = os.ReadFile(privKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read private key file: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
|
||||
}
|
||||
|
||||
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact private key: %w", err)
|
||||
}
|
||||
|
||||
artifactData, err := os.ReadFile(artifactFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.SignData(privateKey, artifactData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign artifact: %w", err)
|
||||
}
|
||||
|
||||
sigFile := artifactFile + ".sig"
|
||||
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Artifact signed successfully.\n")
|
||||
cmd.Printf("Signature file: %s\n", sigFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
|
||||
cmd.Println("🔍 Verifying artifact...")
|
||||
|
||||
// Read artifact public key
|
||||
pubKeyPEM, err := os.ReadFile(pubKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read public key file: %w", err)
|
||||
}
|
||||
|
||||
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact public key: %w", err)
|
||||
}
|
||||
|
||||
// Read artifact data
|
||||
artifactData, err := os.ReadFile(artifactFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact file: %w", err)
|
||||
}
|
||||
|
||||
// Read signature
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read signature file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Validate artifact
|
||||
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
|
||||
return fmt.Errorf("artifact verification failed: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Artifact signature is valid")
|
||||
cmd.Printf("Artifact: %s\n", artifactFile)
|
||||
cmd.Printf("Signed by key: %s\n", signature.KeyID)
|
||||
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
|
||||
cmd.Println("🔍 Verifying artifact key...")
|
||||
|
||||
// Read artifact key data
|
||||
artifactKeyData, err := os.ReadFile(artifactKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact key file: %w", err)
|
||||
}
|
||||
|
||||
// Read root public key(s)
|
||||
rootKeyData, err := os.ReadFile(rootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root key file: %w", err)
|
||||
}
|
||||
|
||||
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse root public key(s): %w", err)
|
||||
}
|
||||
|
||||
// Read signature
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read signature file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Read optional revocation list
|
||||
var revocationList *reposign.RevocationList
|
||||
if revocationFile != "" {
|
||||
revData, err := os.ReadFile(revocationFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read revocation file: %w", err)
|
||||
}
|
||||
|
||||
revocationList, err = reposign.ParseRevocationList(revData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate artifact key(s)
|
||||
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("artifact key verification failed: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Artifact key(s) verified successfully")
|
||||
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
|
||||
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
|
||||
for i, key := range validKeys {
|
||||
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
|
||||
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
|
||||
if !key.Metadata.ExpiresAt.IsZero() {
|
||||
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
|
||||
} else {
|
||||
cmd.Printf(" Expires: Never\n")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseRootPublicKeys parses a root public key from PEM data
|
||||
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
|
||||
key, err := reposign.ParseRootPublicKey(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []reposign.PublicKey{key}, nil
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "signer",
|
||||
Short: "A CLI tool for managing cryptographic keys and artifacts",
|
||||
Long: `signer is a command-line tool that helps you manage
|
||||
root keys, artifact keys, and revocation lists securely.`,
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
rootCmd.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -1,220 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
|
||||
)
|
||||
|
||||
var (
|
||||
keyID string
|
||||
revocationListFile string
|
||||
privateRootKeyFile string
|
||||
publicRootKeyFile string
|
||||
signatureFile string
|
||||
expirationDuration time.Duration
|
||||
)
|
||||
|
||||
var createRevocationListCmd = &cobra.Command{
|
||||
Use: "create-revocation-list",
|
||||
Short: "Create a new revocation list signed by the private root key",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
var extendRevocationListCmd = &cobra.Command{
|
||||
Use: "extend-revocation-list",
|
||||
Short: "Extend an existing revocation list with a given key ID",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
var verifyRevocationListCmd = &cobra.Command{
|
||||
Use: "verify-revocation-list",
|
||||
Short: "Verify a revocation list signature using the public root key",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createRevocationListCmd)
|
||||
rootCmd.AddCommand(extendRevocationListCmd)
|
||||
rootCmd.AddCommand(verifyRevocationListCmd)
|
||||
|
||||
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
|
||||
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
|
||||
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
|
||||
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
|
||||
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create revocation list: %w", err)
|
||||
}
|
||||
|
||||
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output files: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Revocation list created successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
|
||||
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
rlBytes, err := os.ReadFile(revocationListFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||
}
|
||||
|
||||
rl, err := reposign.ParseRevocationList(rlBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||
}
|
||||
|
||||
kid, err := reposign.ParseKeyID(keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid key ID: %w", err)
|
||||
}
|
||||
|
||||
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extend revocation list: %w", err)
|
||||
}
|
||||
|
||||
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output files: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Revocation list extended successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
|
||||
// Read revocation list file
|
||||
rlBytes, err := os.ReadFile(revocationListFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||
}
|
||||
|
||||
// Read signature file
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read signature file: %w", err)
|
||||
}
|
||||
|
||||
// Read public root key file
|
||||
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read public root key file: %w", err)
|
||||
}
|
||||
|
||||
// Parse public root key
|
||||
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public root key: %w", err)
|
||||
}
|
||||
|
||||
// Parse signature
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Validate revocation list
|
||||
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate revocation list: %w", err)
|
||||
}
|
||||
|
||||
// Display results
|
||||
cmd.Println("✅ Revocation list signature is valid")
|
||||
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
|
||||
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
|
||||
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
|
||||
|
||||
if len(rl.Revoked) > 0 {
|
||||
cmd.Println("\nRevoked Keys:")
|
||||
for keyID, revokedTime := range rl.Revoked {
|
||||
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
|
||||
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
|
||||
return fmt.Errorf("failed to write revocation list file: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
|
||||
return fmt.Errorf("failed to write signature file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
privKeyFile string
|
||||
pubKeyFile string
|
||||
rootExpiration time.Duration
|
||||
)
|
||||
|
||||
var createRootKeyCmd = &cobra.Command{
|
||||
Use: "create-root-key",
|
||||
Short: "Create a new root key pair",
|
||||
Long: `Create a new root key pair and specify an expiration time for it.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Validate expiration
|
||||
if rootExpiration <= 0 {
|
||||
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||
}
|
||||
|
||||
// Run main logic
|
||||
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
|
||||
return fmt.Errorf("failed to generate root key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createRootKeyCmd)
|
||||
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
|
||||
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
|
||||
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
|
||||
|
||||
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
|
||||
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate root key: %w", err)
|
||||
}
|
||||
|
||||
// Write private key
|
||||
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
|
||||
}
|
||||
|
||||
// Write public key
|
||||
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("%s\n\n", rk.String())
|
||||
cmd.Printf("✅ Root key pair generated successfully.\n")
|
||||
return nil
|
||||
}
|
||||
@@ -14,9 +14,7 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
@@ -36,7 +34,6 @@ const (
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -50,8 +47,6 @@ var (
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
sshNoBrowser bool
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,7 +56,6 @@ var (
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -71,18 +65,14 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
@@ -107,9 +97,9 @@ SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
-t, --tty Force pseudo-terminal allocation
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
-i, --identity string Path to SSH private key file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
@@ -117,10 +107,8 @@ Examples:
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
@@ -155,10 +143,10 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
errCh <- err
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
@@ -166,10 +154,6 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-sshctx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
@@ -187,21 +171,6 @@ func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
@@ -213,7 +182,6 @@ func resetSSHGlobals() {
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
sshNoBrowser = false
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
@@ -383,12 +351,10 @@ type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
RequestPTY bool
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
NoBrowser bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
@@ -400,7 +366,6 @@ type sshFlags struct {
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
@@ -408,25 +373,22 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||
fs.String("user", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.String("identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, flags
|
||||
}
|
||||
@@ -447,10 +409,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
if errors.Is(err, flag.ErrHelp) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
return parseHostnameAndCommand(filteredArgs)
|
||||
}
|
||||
|
||||
remaining := fs.Args()
|
||||
@@ -465,12 +424,10 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
requestPTY = flags.RequestPTY
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
sshNoBrowser = flags.NoBrowser
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
@@ -530,7 +487,6 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
NoBrowser: sshNoBrowser,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -564,29 +520,10 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
var err error
|
||||
if requestPTY {
|
||||
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||
} else {
|
||||
err = c.ExecuteCommandWithIO(ctx, command)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
os.Exit(exitErr.ExitStatus())
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote command exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -598,13 +535,6 @@ func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -634,11 +564,7 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateDestinationPort(remoteAddr); err != nil {
|
||||
return fmt.Errorf("invalid remote address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -656,11 +582,7 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateDestinationPort(localAddr); err != nil {
|
||||
return fmt.Errorf("invalid local address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -671,35 +593,6 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDestinationPort checks that the destination address has a valid port.
|
||||
// Port 0 is only valid for bind addresses (where the OS picks an available port),
|
||||
// not for destination addresses where we need to connect.
|
||||
func validateDestinationPort(addr string) error {
|
||||
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
return fmt.Errorf("port 0 is not valid for destination address")
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("port %d out of range (1-65535)", port)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
@@ -809,9 +702,7 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
|
||||
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
@@ -823,23 +714,10 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||
// where command-line flags cannot be passed. Default is to open browser.
|
||||
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
var browserOpener func(string) error
|
||||
if !noBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.Close(); err != nil {
|
||||
log.Debugf("close SSH proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
@@ -858,8 +736,7 @@ var sshDetectCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
@@ -868,21 +745,15 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port %q: %v", portStr, err)
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("SSH server detection failed: %v", err)
|
||||
cancel()
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
cancel()
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -52,7 +51,7 @@ func sftpMainDirect(cmd *cobra.Command) error {
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
@@ -667,51 +667,3 @@ func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "invalid long flag before hostname",
|
||||
args: []string{"--invalid-flag", "hostname"},
|
||||
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||
},
|
||||
{
|
||||
name: "invalid short flag before hostname",
|
||||
args: []string{"-x", "hostname"},
|
||||
description: "Invalid short flag should return parse error",
|
||||
},
|
||||
{
|
||||
name: "invalid flag with value before hostname",
|
||||
args: []string{"--invalid-option=value", "hostname"},
|
||||
description: "Invalid flag with value should return parse error",
|
||||
},
|
||||
{
|
||||
name: "typo in known flag",
|
||||
args: []string{"--por", "2222", "hostname"},
|
||||
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
|
||||
// Should return an error for invalid flags
|
||||
assert.Error(t, err, tt.description)
|
||||
|
||||
// Should not have set host to the invalid flag
|
||||
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -124,7 +124,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -13,12 +13,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -26,6 +20,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -88,7 +84,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
@@ -110,21 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
@@ -216,7 +216,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -287,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -363,18 +355,14 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -479,10 +467,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -603,11 +587,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var updateCmd *cobra.Command
|
||||
|
||||
func isUpdateBinary() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
updateCmd = &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update the NetBird client application",
|
||||
RunE: updateFunc,
|
||||
}
|
||||
|
||||
tempDirFlag string
|
||||
installerFile string
|
||||
serviceDirFlag string
|
||||
dryRunFlag bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
|
||||
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
|
||||
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
|
||||
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
|
||||
}
|
||||
|
||||
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
|
||||
func isUpdateBinary() bool {
|
||||
// Remove extension for cross-platform compatibility
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
baseName := filepath.Base(execPath)
|
||||
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||
|
||||
return name == installer.UpdaterBinaryNameWithoutExtension()
|
||||
}
|
||||
|
||||
func updateFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupLogToFile(tempDirFlag); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("updater started: %s", serviceDirFlag)
|
||||
updater := installer.NewWithDir(tempDirFlag)
|
||||
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
|
||||
log.Errorf("failed to update application: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupLogToFile(dir string) error {
|
||||
logFile := filepath.Join(dir, installer.LogFile)
|
||||
|
||||
if _, err := os.Stat(logFile); err == nil {
|
||||
if err := os.Remove(logFile); err != nil {
|
||||
log.Errorf("failed to remove existing log file: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return util.InitLog(logLevel, util.LogConsole, logFile)
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -31,11 +29,6 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPeerConnectionTimeout = 60 * time.Second
|
||||
peerConnectionPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -45,7 +38,6 @@ type Client struct {
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
@@ -169,17 +161,11 @@ func New(opts Options) (*Client, error) {
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.connect != nil {
|
||||
if c.cancel != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||
defer func() {
|
||||
if c.connect == nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
@@ -187,9 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
@@ -213,7 +197,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
c.cancel = cancel
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -228,23 +211,17 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
connect := c.connect
|
||||
go func() {
|
||||
done <- connect.Stop()
|
||||
done <- c.connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.connect = nil
|
||||
c.cancel = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.connect = nil
|
||||
c.cancel = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
@@ -264,40 +241,18 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
// With lazy connections, the connection is established on first traffic.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
|
||||
|
||||
// Check context status upfront
|
||||
if ctx.Err() != nil {
|
||||
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when DialContext is called. The netstack
|
||||
// dial triggers WireGuard traffic which activates the lazy connection.
|
||||
|
||||
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
|
||||
conn, err := nsnet.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("embed.Dial: successfully connected to %s", address)
|
||||
return conn, nil
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
@@ -360,90 +315,6 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
return recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncResp, err := engine.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get sync response: %w", err)
|
||||
}
|
||||
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// WaitForPeerConnection waits for a peer with the given IP to be connected.
|
||||
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
|
||||
logrus.Infof("Waiting for peer %s to be connected", peerIP)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
|
||||
case <-ticker.C:
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
logrus.Debugf("Error getting status while waiting for peer: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range status.Peers {
|
||||
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
|
||||
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
// Note: ConnectClient doesn't have SetLogLevel method
|
||||
_ = connect
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -41,13 +40,19 @@ type aclManager struct {
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
return &aclManager{
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
@@ -93,8 +98,8 @@ func (m *aclManager) AddPeerFiltering(
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
@@ -108,18 +113,14 @@ func (m *aclManager) AddPeerFiltering(
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
}
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
@@ -171,16 +172,11 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
@@ -194,7 +190,10 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
shouldDestroyIpset = true
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
@@ -207,16 +206,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@@ -275,19 +264,11 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
@@ -386,8 +367,11 @@ func (m *aclManager) updateState() {
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
@@ -432,61 +416,3 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
var denyRuleIndex, acceptRuleIndex int = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
@@ -107,6 +107,10 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -228,12 +232,12 @@ func (r *router) findSets(rule []string) []string {
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
@@ -242,7 +246,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
@@ -911,8 +915,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
@@ -989,37 +993,3 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
var acceptRuleIndex, denyRuleIndex int = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
@@ -208,13 +208,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
if lookup.SetName == "accept-http" {
|
||||
hasAcceptHTTPSet = true
|
||||
case "deny-http":
|
||||
} else if lookup.SetName == "deny-http" {
|
||||
hasDenyHTTPSet = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
@@ -224,10 +222,9 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
if verdict.Kind == expr.VerdictAccept {
|
||||
action = "ACCEPT"
|
||||
case expr.VerdictDrop:
|
||||
} else if verdict.Kind == expr.VerdictDrop {
|
||||
action = "DROP"
|
||||
}
|
||||
}
|
||||
@@ -389,97 +386,6 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
const octet2Count = 25
|
||||
const octet3Count = 255
|
||||
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
|
||||
for i := 1; i < octet2Count; i++ {
|
||||
for j := 1; j < octet3Count; j++ {
|
||||
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
@@ -27,11 +27,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
tableNat = "nat"
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
@@ -48,12 +44,10 @@ const (
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
@@ -97,7 +91,11 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, fmt.Errorf("load filter table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
@@ -177,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
@@ -189,39 +187,14 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
@@ -263,12 +236,9 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
@@ -280,7 +250,11 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -515,35 +489,16 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
var subEnd int
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd = min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
}
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
@@ -740,7 +695,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() {
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
@@ -806,6 +761,8 @@ func (r *router) addPostroutingRules() {
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
@@ -882,7 +839,7 @@ func (r *router) addMSSClampingRules() error {
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return r.conn.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
@@ -982,21 +939,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1009,11 +953,11 @@ func (r *router) acceptFilterTableRules() error {
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// iptables is not available but the filter table exists
|
||||
// filter table exists but iptables is not
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
return r.acceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
@@ -1024,7 +968,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
@@ -1032,7 +976,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
@@ -1052,70 +996,18 @@ func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *router) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
iifRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1138,19 +1030,30 @@ func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
oifRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
inputRule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1164,44 +1067,32 @@ func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1209,101 +1100,27 @@ func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *router) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1311,13 +1128,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -1379,7 +1196,7 @@ func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
return fmt.Errorf(" unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix := iface.Address().Network
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
@@ -781,7 +795,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum = pseudoSum
|
||||
var sum uint32 = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -931,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
|
||||
if blocked {
|
||||
pnum := getProtocolFromPacket(d)
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
@@ -996,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return false
|
||||
}
|
||||
|
||||
protoLayer := d.decoded[1]
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: proto,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
@@ -1040,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
return layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
return layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
if ipLayer == layers.LayerTypeIPv6 {
|
||||
return layers.LayerTypeICMPv6
|
||||
}
|
||||
return layers.LayerTypeICMPv4
|
||||
case firewall.ProtocolALL:
|
||||
return layerTypeAll
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return nftypes.TCP
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMP
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1238,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
return false
|
||||
}
|
||||
|
||||
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
@@ -1280,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
return sourceMatched
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
|
||||
@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -17,7 +16,7 @@ type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu atomic.Uint32
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu.Load()
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *endpoint) Close() {
|
||||
// Endpoint cleanup - nothing to do as device is managed externally
|
||||
}
|
||||
|
||||
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
|
||||
// Link address is not used for this endpoint type
|
||||
}
|
||||
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mtu.Store(mtu)
|
||||
}
|
||||
|
||||
func (e *endpoint) SetOnCloseAction(func()) {
|
||||
// No action needed on close
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -36,16 +35,14 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
endpoint.mtu.Store(uint32(mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,8 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,95 +14,30 @@ import (
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
// For Echo Requests, send and wait for response
|
||||
if icmpHdr.Type() == header.ICMPv4Echo {
|
||||
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
|
||||
if !f.hasRawICMPAccess {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
|
||||
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
}
|
||||
}()
|
||||
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu.Load())
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
return 0
|
||||
}
|
||||
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
|
||||
// This is used as a fallback when raw socket access is not available.
|
||||
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
|
||||
icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
|
||||
// Returns the size of the injected packet.
|
||||
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyICMPHdr := header.ICMPv4(replyICMP)
|
||||
replyICMPHdr.SetType(header.ICMPv4EchoReply)
|
||||
replyICMPHdr.SetChecksum(0)
|
||||
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
|
||||
|
||||
return f.injectICMPReply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
|
||||
// Returns the total size of the injected packet, or 0 if injection failed.
|
||||
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(&wq, ep)
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return true
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
|
||||
@@ -130,7 +130,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
|
||||
@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
|
||||
@@ -234,10 +234,9 @@ func TestInboundPortDNATNegative(t *testing.T) {
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
case layers.IPProtocolUDP:
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -34,7 +34,7 @@ type RouteRule struct {
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
protoLayer gopacket.LayerType
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
|
||||
@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
protoLayer := d.decoded[1]
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -18,6 +20,9 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
@@ -26,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
conn.Connect()
|
||||
|
||||
state := conn.GetState()
|
||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||
if !conn.WaitForStateChange(ctx, state) {
|
||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||
}
|
||||
state = conn.GetState()
|
||||
}
|
||||
|
||||
if state == connectivity.Shutdown {
|
||||
return ErrConnectionShutdown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
@@ -43,22 +68,25 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial context: %w", err)
|
||||
return nil, fmt.Errorf("new client: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -27,23 +27,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
||||
}
|
||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
buf := bufs[0]
|
||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = size
|
||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
||||
eps[0] = stdEp
|
||||
return 1, nil
|
||||
}
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -20,12 +19,11 @@ import (
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
// todo: review if we can eliminate the TunAdapter
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
@@ -34,19 +32,17 @@ type WGTunDevice struct {
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
renewableTun *RenewableTUN
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
renewableTun: NewRenewableTUN(),
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,17 +65,14 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
t.name = name
|
||||
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
@@ -111,23 +104,6 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||
if t.device == nil {
|
||||
return fmt.Errorf("device not initialized")
|
||||
}
|
||||
|
||||
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to renew Android interface: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -43,31 +45,10 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
|
||||
// This typically means the Swift code couldn't find the utun control socket.
|
||||
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
|
||||
|
||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
log.Infof("create tun interface")
|
||||
|
||||
var tunDevice tun.Device
|
||||
var err error
|
||||
|
||||
// Validate the tunnel file descriptor.
|
||||
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
|
||||
// A value of 0 means the Swift code couldn't find the utun control socket
|
||||
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
|
||||
// tvOS SDK headers). This is a hard error - there's no viable fallback
|
||||
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
|
||||
if t.tunFd == 0 {
|
||||
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
|
||||
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
|
||||
return nil, ErrInvalidTunnelFD
|
||||
}
|
||||
|
||||
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
|
||||
var dupTunFd int
|
||||
dupTunFd, err = unix.Dup(t.tunFd)
|
||||
dupTunFd, err := unix.Dup(t.tunFd)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
@@ -79,7 +60,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
_ = unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to create new tun device from fd: %v", err)
|
||||
_ = unix.Close(dupTunFd)
|
||||
|
||||
@@ -2,13 +2,6 @@
|
||||
|
||||
package device
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||
return t.create()
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||
// Doesn't make sense in Android for Netstack.
|
||||
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||
}
|
||||
|
||||
@@ -1,309 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// closeAwareDevice wraps a tun.Device along with a flag
|
||||
// indicating whether its Close method was called.
|
||||
//
|
||||
// It also redirects tun.Device's Events() to a separate goroutine
|
||||
// and closes it when Close is called.
|
||||
//
|
||||
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||
// goroutine is awaited and closed only once.
|
||||
type closeAwareDevice struct {
|
||||
isClosed atomic.Bool
|
||||
tun.Device
|
||||
closeEventCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||
return &closeAwareDevice{
|
||||
Device: tunDevice,
|
||||
isClosed: atomic.Bool{},
|
||||
closeEventCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||
// to the given channel (RenewableTUN's events channel).
|
||||
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-c.Device.Events():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ev == tun.EventDown {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- ev:
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close calls the underlying Device's Close method
|
||||
// after setting isClosed to true.
|
||||
func (c *closeAwareDevice) Close() (err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.isClosed.Store(true)
|
||||
close(c.closeEventCh)
|
||||
err = c.Device.Close()
|
||||
c.wg.Wait()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *closeAwareDevice) IsClosed() bool {
|
||||
return c.isClosed.Load()
|
||||
}
|
||||
|
||||
type RenewableTUN struct {
|
||||
devices []*closeAwareDevice
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
events chan tun.Event
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewRenewableTUN() *RenewableTUN {
|
||||
r := &RenewableTUN{
|
||||
devices: make([]*closeAwareDevice, 0),
|
||||
mu: sync.Mutex{},
|
||||
events: make(chan tun.Event, 16),
|
||||
}
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) File() *os.File {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
file := dev.File()
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries reading from the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err = dev.Read(bufs, sizes, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// swap in progress; retry on the newest instead of returning the error
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return n, err // propagate non-swap error
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries writing to the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := dev.Write(bufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) MTU() (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
mtu, err := dev.MTU()
|
||||
if err == nil {
|
||||
return mtu, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Name() (string, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return "", io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
name, err := dev.Name()
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||
// once it is added.
|
||||
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||
return r.events
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Close() error {
|
||||
// Attempts to set the RenewableTUN closed flag to true.
|
||||
// If it's already true, returns immediately.
|
||||
if !r.closed.CompareAndSwap(false, true) {
|
||||
return nil // already closed: idempotent
|
||||
}
|
||||
r.mu.Lock()
|
||||
devices := r.devices
|
||||
r.devices = nil
|
||||
r.cond.Broadcast()
|
||||
r.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
|
||||
log.Debugf("closing %d devices", len(devices))
|
||||
for _, device := range devices {
|
||||
if err := device.Close(); err != nil {
|
||||
log.Debugf("error closing a device: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
close(r.events)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||
r.mu.Lock()
|
||||
if r.closed.Load() {
|
||||
r.mu.Unlock()
|
||||
_ = device.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var toClose *closeAwareDevice
|
||||
if len(r.devices) > 0 {
|
||||
toClose = r.devices[len(r.devices)-1]
|
||||
}
|
||||
|
||||
cad := newClosableDevice(device)
|
||||
cad.redirectEvents(r.events)
|
||||
|
||||
r.devices = []*closeAwareDevice{cad}
|
||||
r.cond.Broadcast()
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
if toClose != nil {
|
||||
if err := toClose.Close(); err != nil {
|
||||
log.Debugf("error closing last device: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) waitForDevice() bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for len(r.devices) == 0 && !r.closed.Load() {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return !r.closed.Load()
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(r.devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.devices[len(r.devices)-1]
|
||||
}
|
||||
@@ -21,6 +21,5 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
RenewTun(fd int) error
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -24,7 +24,3 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on non-android")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// todo: review does this function really necessary or can we merge it with iOS
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
@@ -23,9 +22,3 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
||||
func (w *WGIface) Create() error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.tun.RenewTun(fd)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,3 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on this platform")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -10,13 +9,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// keep darwin compatibility
|
||||
@@ -41,7 +40,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
addr := "100.64.0.1/8"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -124,7 +123,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||
wgIP := "10.99.99.1/32"
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -167,7 +166,7 @@ func Test_Close(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -212,7 +211,7 @@ func TestRecreation(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -285,7 +284,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||
wgIP := "10.99.99.5/30"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -340,7 +339,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.9/30"
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -410,7 +409,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.13/30"
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -472,7 +471,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
peer2wgPort := 33200
|
||||
|
||||
keepAlive := 1 * time.Second
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -515,7 +514,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
|
||||
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||
newNet, err = stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
|
||||
}
|
||||
|
||||
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
|
||||
log.Debugf("dialing %s %s", network, addr)
|
||||
conn, err := d.net.Dial(network, addr)
|
||||
if err != nil {
|
||||
log.Warnf("NSDialer.Dial failed: %s", err)
|
||||
log.Debugf("failed to deal connection: %s", err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -13,9 +12,8 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -201,7 +199,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,12 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
|
||||
)
|
||||
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
mtu uint16
|
||||
@@ -29,12 +22,6 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
if isEBPFDisabled() {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
@@ -60,16 +47,3 @@ func (w *KernelFactory) Free() error {
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
|
||||
func isEBPFDisabled() bool {
|
||||
val := os.Getenv(envDisableEBPFWGProxy)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
@@ -128,34 +128,9 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -60,45 +60,38 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -114,7 +107,5 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||
@@ -48,10 +46,9 @@ type PKCEAuthorizationFlow struct {
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
|
||||
// find the first available redirect URL
|
||||
for _, redirectURL := range config.RedirectURLs {
|
||||
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||
if !isRedirectURLPortUsed(redirectURL) {
|
||||
availableRedirectURL = redirectURL
|
||||
break
|
||||
}
|
||||
@@ -105,16 +102,13 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
switch p.providerConfig.LoginFlag {
|
||||
case common.LoginFlagPromptLogin:
|
||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
case common.LoginFlagMaxAge0:
|
||||
}
|
||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
@@ -195,20 +189,17 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
if authErrorDesc != "" {
|
||||
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||
}
|
||||
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("authentication failed: missing code")
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
@@ -237,7 +228,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||
@@ -285,22 +276,15 @@ func createCodeChallenge(codeVerifier string) string {
|
||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||
}
|
||||
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse redirect URL: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
port := parsedURL.Port()
|
||||
|
||||
if isPortInExcludedRange(port, excludedRanges) {
|
||||
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -314,33 +298,6 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
|
||||
return true
|
||||
}
|
||||
|
||||
// excludedPortRange represents a range of excluded ports.
|
||||
type excludedPortRange struct {
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||
if len(excludedRanges) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port number %s: %v", port, err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range excludedRanges {
|
||||
if portNum >= r.start && portNum <= r.end {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package auth
|
||||
|
||||
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
return nil
|
||||
}
|
||||
@@ -2,11 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -23,28 +20,22 @@ func TestPromptLogin(t *testing.T) {
|
||||
name string
|
||||
loginFlag mgm.LoginFlag
|
||||
disablePromptLogin bool
|
||||
expectContains []string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
expectContains: []string{promptLogin},
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
expect: promptLogin,
|
||||
},
|
||||
{
|
||||
name: "Max age 0",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expectContains: []string{maxAge0},
|
||||
name: "Max age 0 login",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expect: maxAge0,
|
||||
},
|
||||
{
|
||||
name: "Disable prompt login",
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
disablePromptLogin: true,
|
||||
expectContains: []string{},
|
||||
},
|
||||
{
|
||||
name: "None flag should not add parameters",
|
||||
loginFlag: mgm.LoginFlagNone,
|
||||
expectContains: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -59,7 +50,6 @@ func TestPromptLogin(t *testing.T) {
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
LoginFlag: tc.loginFlag,
|
||||
DisablePromptLogin: tc.disablePromptLogin,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
@@ -70,153 +60,12 @@ func TestPromptLogin(t *testing.T) {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tc.expectContains {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||
if !tc.disablePromptLogin {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||
} else {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPortInExcludedRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at start of range",
|
||||
port: "8000",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at end of range",
|
||||
port: "8100",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port before range",
|
||||
port: "7999",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Port after range",
|
||||
port: "8101",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Nil excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: nil,
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port in second range",
|
||||
port: "9050",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port not in any range",
|
||||
port: "8500",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid port string",
|
||||
port: "invalid",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty port string",
|
||||
port: "",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURL string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedUsed bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
redirectURL: "http://127.0.0.1:8080/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port actually in use",
|
||||
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||
excludedRanges: nil,
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port not in use and not excluded",
|
||||
redirectURL: "http://127.0.0.1:65432/",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL without port",
|
||||
redirectURL: "not-a-valid-url",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Port excluded even if not in use",
|
||||
redirectURL: "http://127.0.0.1:8050/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
ranges, err := getExcludedPortRangesFromNetsh()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ranges
|
||||
}
|
||||
|
||||
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("netsh command: %w", err)
|
||||
}
|
||||
|
||||
return parseExcludedPortRanges(string(output))
|
||||
}
|
||||
|
||||
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||
var ranges []excludedPortRange
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
|
||||
foundHeader := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||
foundHeader = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !foundHeader {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(line, "----------") {
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
startPort, err := strconv.Atoi(fields[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
endPort, err := strconv.Atoi(fields[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
netshOutput string
|
||||
expectedRanges []excludedPortRange
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid netsh output with multiple ranges",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
5357 5357 *
|
||||
50000 50059 *
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 5357, end: 5357},
|
||||
{start: 50000, end: 50059},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty output",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
`,
|
||||
expectedRanges: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
netshOutput: `
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
8080 8090
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 8080, end: 8090},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRanges, ranges)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
ranges := getSystemExcludedPortRanges()
|
||||
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||
|
||||
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener1.Close()
|
||||
}()
|
||||
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
Scope: "openid email profile",
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||
},
|
||||
UseIDToken: true,
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, flow)
|
||||
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||
"Should skip port in use and select available port")
|
||||
}
|
||||
@@ -24,14 +24,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -43,13 +39,11 @@ import (
|
||||
)
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
doInitialAutoUpdate bool
|
||||
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
|
||||
persistSyncResponse bool
|
||||
}
|
||||
@@ -58,15 +52,13 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||
engineMutex: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +74,6 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
@@ -91,7 +82,6 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
@@ -170,33 +160,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return err
|
||||
}
|
||||
|
||||
var path string
|
||||
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||
// On mobile, use the provided state file path directly
|
||||
if !fileExists(mobileDependency.StateFilePath) {
|
||||
if err := createFile(mobileDependency.StateFilePath); err != nil {
|
||||
log.Errorf("failed to create state file: %v", err)
|
||||
// we are not exiting as we can run without the state manager
|
||||
}
|
||||
}
|
||||
path = mobileDependency.StateFilePath
|
||||
} else {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path = sm.GetStatePath()
|
||||
}
|
||||
stateManager := statemanager.New(path)
|
||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||
if err == nil {
|
||||
updateManager.CheckUpdateSuccess(c.ctx)
|
||||
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
@@ -308,25 +271,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||
if c.doInitialAutoUpdate {
|
||||
log.Infof("start engine by ui, run auto-update check")
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
c.doInitialAutoUpdate = false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
@@ -338,14 +291,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
// todo: consider to remove this condition. Is not thread safe.
|
||||
// We should always call Stop(), but we need to verify that it is idempotent
|
||||
if engine.wgInterface != nil {
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
@@ -420,19 +371,6 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
||||
return syncResponse, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager != nil {
|
||||
fwManager.SetLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -45,8 +44,6 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
@@ -57,7 +54,6 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -111,9 +107,6 @@ go tool pprof -http=:8088 heap.prof
|
||||
|
||||
This will open a web browser tab with the profiling information.
|
||||
|
||||
Stack Trace
|
||||
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
|
||||
|
||||
Routes
|
||||
The routes.txt file contains detailed routing table information in a tabular format:
|
||||
|
||||
@@ -191,20 +184,6 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -332,10 +311,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSyncResponse(); err != nil {
|
||||
return fmt.Errorf("add sync response: %w", err)
|
||||
}
|
||||
@@ -363,10 +338,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add systemd logs: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addUpdateLogs(); err != nil {
|
||||
log.Errorf("failed to add updater logs: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -386,10 +357,6 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
@@ -535,18 +502,6 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
stackTrace := bytes.NewReader(buf[:n])
|
||||
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
|
||||
return fmt.Errorf("add stack trace file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addInterfaces() error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -655,29 +610,6 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addUpdateLogs() error {
|
||||
inst := installer.New()
|
||||
logFiles := inst.LogFiles()
|
||||
if len(logFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("adding updater logs")
|
||||
for _, logFile := range logFiles {
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
log.Warnf("failed to read update log file %s: %v", logFile, err)
|
||||
continue
|
||||
}
|
||||
|
||||
baseName := filepath.Base(logFile)
|
||||
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
|
||||
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
pattern := sm.GetStatePath()
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -507,13 +507,15 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
|
||||
if p.Base == expr.PayloadBaseNetworkHeader {
|
||||
switch p.Offset {
|
||||
case 12:
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
case 16:
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,3 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -38,8 +38,6 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -76,9 +76,6 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.NonAuthoritative {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
@@ -109,9 +106,8 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
SearchDomainDisabled: true,
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
|
||||
@@ -3,21 +3,17 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
@@ -47,23 +43,7 @@ type HandlerChain struct {
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
requestID string
|
||||
shouldContinue bool
|
||||
response *dns.Msg
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// RequestID returns the request ID for tracing
|
||||
func (w *ResponseWriterChain) RequestID() string {
|
||||
return w.requestID
|
||||
}
|
||||
|
||||
// SetMeta sets a metadata key-value pair for logging
|
||||
func (w *ResponseWriterChain) SetMeta(key, value string) {
|
||||
if w.meta == nil {
|
||||
w.meta = make(map[string]string)
|
||||
}
|
||||
w.meta[key] = value
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
@@ -72,7 +52,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -122,8 +101,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
|
||||
c.logHandlers()
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
@@ -163,109 +140,68 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
c.logHandlers()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logHandlers logs the current handler chain state. Caller must hold the lock.
|
||||
func (c *HandlerChain) logHandlers() {
|
||||
if !log.IsLevelEnabled(log.TraceLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
|
||||
for _, h := range c.handlers {
|
||||
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
|
||||
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
|
||||
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
|
||||
" priority=" + strconv.Itoa(h.Priority) + "\n")
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
|
||||
c.mu.RLock()
|
||||
handlers := slices.Clone(c.handlers)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||
for _, h := range handlers {
|
||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
matched := c.isHandlerMatch(qname, entry)
|
||||
|
||||
handlerName := entry.OrigPattern
|
||||
if s, ok := entry.Handler.(interface{ String() string }); ok {
|
||||
handlerName = s.String()
|
||||
}
|
||||
if matched {
|
||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
|
||||
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
requestID: requestID,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
logger.Tracef("handler requested continue for domain=%s", qname)
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
c.logResponse(logger, chainWriter, qname, startTime)
|
||||
return
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
// Only log continue for non-management cache handlers to reduce noise
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
logger.Tracef("no handler found for domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
log.Tracef("no handler found for domain=%s", qname)
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeRefused)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
|
||||
if cw.response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var meta string
|
||||
for k, v := range cw.meta {
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
|
||||
@@ -11,6 +11,11 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@@ -105,9 +110,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: customZone.SearchDomainDisabled,
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,52 +1,30 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const externalResolutionTimeout = 4 * time.Second
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
domains map[domain.Domain]struct{}
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
domains: make(map[domain.Domain]struct{}),
|
||||
zones: make(map[domain.Domain]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,18 +37,7 @@ func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {
|
||||
if d.cancel != nil {
|
||||
d.cancel()
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
}
|
||||
func (d *Resolver) Stop() {}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
@@ -81,85 +48,35 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
logger.Debug("received local resolver request with no question")
|
||||
log.Debugf("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
|
||||
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
|
||||
d.continueToNext(logger, w, r)
|
||||
return
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// Check if we have any records for this domain name with different types
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
logger.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// determineRcode returns the appropriate DNS response code.
|
||||
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
||||
// even if CNAME records are included in the answer.
|
||||
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
||||
// Use the rcode from lookup - this properly handles CNAME chains where
|
||||
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
||||
if result.rcode != 0 {
|
||||
return result.rcode
|
||||
}
|
||||
|
||||
// No records found, but domain exists with different record types (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
// findZone finds the matching zone for a query name using reverse suffix lookup.
|
||||
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
|
||||
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
|
||||
qname = strings.ToLower(dns.Fqdn(qname))
|
||||
for {
|
||||
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
|
||||
return nonAuth, true
|
||||
}
|
||||
// Move to parent domain
|
||||
idx := strings.Index(qname, ".")
|
||||
if idx == -1 || idx == len(qname)-1 {
|
||||
return false, false
|
||||
}
|
||||
qname = qname[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// shouldFallthrough checks if the query should fallthrough to the next handler.
|
||||
// Returns true if the queried name belongs to a non-authoritative zone.
|
||||
func (d *Resolver) shouldFallthrough(qname string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
nonAuth, found := d.findZone(qname)
|
||||
return found && nonAuth
|
||||
}
|
||||
|
||||
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Warnf("failed to write continue signal: %v", err)
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,27 +89,8 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
// isInManagedZone checks if the given name falls within any of our managed zones.
|
||||
// This is used to avoid unnecessary external resolution for CNAME targets that
|
||||
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
||||
// Caller must NOT hold the lock.
|
||||
func (d *Resolver) isInManagedZone(name string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, found := d.findZone(name)
|
||||
return found
|
||||
}
|
||||
|
||||
// lookupResult contains the result of a DNS lookup operation.
|
||||
type lookupResult struct {
|
||||
records []dns.RR
|
||||
rcode int
|
||||
hasExternalData bool
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
|
||||
@@ -200,14 +98,10 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
cnameQuestion := dns.Question{
|
||||
Name: question.Name,
|
||||
Qtype: dns.TypeCNAME,
|
||||
Qclass: question.Qclass,
|
||||
}
|
||||
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
return nil
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
@@ -225,178 +119,20 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||
return recordsCopy
|
||||
}
|
||||
|
||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||
// the final resolved record of the requested type. This is required for musl libc
|
||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
||||
const maxDepth = 8
|
||||
var chain []dns.RR
|
||||
|
||||
for range maxDepth {
|
||||
cnameRecords := d.getRecords(cnameQuestion)
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
chain = append(chain, cnameRecords...)
|
||||
|
||||
cname, ok := cnameRecords[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
targetName := strings.ToLower(cname.Target)
|
||||
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
||||
|
||||
// keep following chain
|
||||
if targetResult.rcode == -1 {
|
||||
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
||||
continue
|
||||
}
|
||||
|
||||
return d.buildChainResult(chain, targetResult)
|
||||
}
|
||||
|
||||
if len(chain) > 0 {
|
||||
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// buildChainResult combines CNAME chain records with the target resolution result.
|
||||
// Per RFC 6604, the final rcode is propagated through the chain.
|
||||
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
||||
records := chain
|
||||
if len(target.records) > 0 {
|
||||
records = append(records, target.records...)
|
||||
}
|
||||
|
||||
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
||||
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: dns.RcodeServerFailure,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: target.rcode,
|
||||
hasExternalData: target.hasExternalData,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
||||
// Returns rcode=-1 to signal "keep following the chain".
|
||||
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
||||
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// another CNAME, keep following
|
||||
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
||||
return lookupResult{rcode: -1}
|
||||
}
|
||||
|
||||
// domain exists locally but not this record type (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// in our zone but doesn't exist (NXDOMAIN)
|
||||
if d.isInManagedZone(targetName) {
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
return d.resolveExternal(logger, targetName, targetType)
|
||||
}
|
||||
|
||||
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.records[q]
|
||||
}
|
||||
|
||||
func (d *Resolver) hasRecord(q dns.Question) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
_, ok := d.records[q]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolveExternal resolves a domain name using the system resolver.
|
||||
// This is used to resolve CNAME targets that point outside our local zone,
|
||||
// which is required for musl libc compatibility (musl expects complete answers).
|
||||
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return lookupResult{rcode: dns.RcodeNotImplemented}
|
||||
}
|
||||
|
||||
resolver := d.resolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
||||
if result.Err != nil {
|
||||
d.logDNSError(logger, name, qtype, result.Err)
|
||||
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: resutil.IPsToRRs(name, result.IPs, 60),
|
||||
rcode: dns.RcodeSuccess,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
// logDNSError logs DNS resolution errors for debugging.
|
||||
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
||||
qtypeName := dns.TypeToString[qtype]
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
||||
} else {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
|
||||
for _, zone := range customZones {
|
||||
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
||||
d.zones[zoneDomain] = zone.NonAuthoritative
|
||||
|
||||
for _, rec := range zone.Records {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
}
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,18 +12,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
// mockResolver implements resolver for testing
|
||||
type mockResolver struct {
|
||||
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if m.lookupFunc != nil {
|
||||
return m.lookupFunc(ctx, network, host)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -124,11 +106,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
|
||||
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
|
||||
update1 := []nbdns.SimpleRecord{record1}
|
||||
update2 := []nbdns.SimpleRecord{record2}
|
||||
|
||||
// Apply first update
|
||||
resolver.Update(zone1)
|
||||
resolver.Update(update1)
|
||||
|
||||
// Verify first update
|
||||
resolver.mu.RLock()
|
||||
@@ -140,7 +122,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||
|
||||
// Apply second update
|
||||
resolver.Update(zone2)
|
||||
resolver.Update(update2)
|
||||
|
||||
// Verify second update
|
||||
resolver.mu.RLock()
|
||||
@@ -169,10 +151,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||
}
|
||||
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
|
||||
update := []nbdns.SimpleRecord{record1, record2}
|
||||
|
||||
// Apply update with both records
|
||||
resolver.Update(zones)
|
||||
resolver.Update(update)
|
||||
|
||||
// Create question that matches both records
|
||||
question := dns.Question{
|
||||
@@ -213,10 +195,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||
}
|
||||
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
|
||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||
|
||||
// Apply update with all three records
|
||||
resolver.Update(zones)
|
||||
resolver.Update(update)
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||
|
||||
@@ -282,7 +264,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with the records
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -397,7 +379,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with both records
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -494,20 +476,6 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
// with 0 records instead of NXDOMAIN
|
||||
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// Mock external resolver for CNAME target resolution
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "target.example.com." {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
if network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "example.netbird.cloud.",
|
||||
@@ -525,7 +493,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -614,808 +582,3 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
|
||||
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
|
||||
t.Run("simple internal CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
|
||||
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "target.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "192.168.1.1", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("multi-hop CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
|
||||
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
|
||||
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3)
|
||||
})
|
||||
|
||||
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
|
||||
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
|
||||
t.Run("chain at max depth resolves", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 7 CNAMEs (under max of 8)
|
||||
for i := 1; i <= 7; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("hop%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("hop%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 8)
|
||||
})
|
||||
|
||||
t.Run("chain exceeding max depth stops", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 10 CNAMEs (exceeds max of 8)
|
||||
for i := 1; i <= 10; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("deep%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("deep%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
// Should NOT have the final A record (chain too deep)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
|
||||
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
|
||||
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
|
||||
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
|
||||
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "93.184.216.34", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
aaaa, ok := resp.Answer[1].(*dns.AAAA)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
|
||||
})
|
||||
|
||||
t.Run("concurrent external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*dns.Msg, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
results[idx] = resp
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, resp := range results {
|
||||
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
|
||||
func TestLocalResolver_ZoneManagement(t *testing.T) {
|
||||
t.Run("Update sets zones correctly", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
}},
|
||||
{Domain: "test.local."},
|
||||
})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("other.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("sub.test.local."))
|
||||
assert.False(t, resolver.isInManagedZone("external.com."))
|
||||
})
|
||||
|
||||
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
|
||||
})
|
||||
|
||||
t.Run("Update clears zones", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
|
||||
resolver.Update(nil)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
|
||||
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
|
||||
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.other.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
|
||||
})
|
||||
|
||||
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
|
||||
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." {
|
||||
if network == "ip6" {
|
||||
// No AAAA records
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
}
|
||||
if network == "ip4" {
|
||||
// But A records exist - domain exists
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
// Table-driven test for all external resolution outcomes
|
||||
externalCases := []struct {
|
||||
name string
|
||||
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
|
||||
expectedRcode int
|
||||
expectedAnswer int
|
||||
}{
|
||||
{
|
||||
name: "external NXDOMAIN (both A and AAAA not found)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (temporary error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTemporary: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (timeout)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTimeout: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (generic error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external success with IPs",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectedAnswer: 2, // CNAME + A
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range externalCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
|
||||
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
|
||||
if tc.expectedAnswer > 0 {
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "first answer should be CNAME")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
|
||||
// trigger fallthrough (Zero bit set) when no records match
|
||||
func TestLocalResolver_Fallthrough(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "existing.custom.zone.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
zones []nbdns.CustomZone
|
||||
queryName string
|
||||
expectFallthrough bool
|
||||
expectRecord bool
|
||||
}{
|
||||
{
|
||||
name: "Authoritative zone returns NXDOMAIN without fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Non-authoritative zone triggers fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: true,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Record found in non-authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
{
|
||||
name: "Record found in authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver.Update(tc.zones)
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG, "Should have received a response")
|
||||
|
||||
if tc.expectFallthrough {
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
|
||||
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
|
||||
} else {
|
||||
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
|
||||
}
|
||||
|
||||
if tc.expectRecord {
|
||||
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
|
||||
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
t.Run("direct record lookup is authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.Authoritative)
|
||||
})
|
||||
|
||||
t.Run("external resolution is not authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
assert.False(t, resp.Authoritative)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Len(t, resp.Answer, 0)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
lookupCtxCanceled := make(chan struct{})
|
||||
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
close(lookupStarted)
|
||||
<-ctx.Done()
|
||||
close(lookupCtxCanceled)
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-lookupStarted
|
||||
resolver.Stop()
|
||||
|
||||
select {
|
||||
case <-lookupCtxCanceled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("external lookup context was not canceled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ServeDNS did not return after Stop")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
|
||||
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "EXAMPLE.COM.",
|
||||
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
|
||||
func BenchmarkFindZone_BestCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Single zone that matches immediately
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough("example.com.")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
|
||||
func BenchmarkFindZone_WorstCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// 100 zones that won't match
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
NonAuthoritative: true,
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query with many labels that won't match any zone
|
||||
qname := "a.b.c.d.e.f.g.h.external.com."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
|
||||
func BenchmarkFindZone_TypicalCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "netbird.cloud.", NonAuthoritative: false},
|
||||
{Domain: "custom.local.", NonAuthoritative: true},
|
||||
})
|
||||
|
||||
// Query for subdomain of user zone
|
||||
qname := "myhost.custom.local."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
|
||||
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query that matches zone50
|
||||
qname := "host.zone50.internal."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user