mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-09 19:16:07 -04:00
Compare commits
241 Commits
nmap/clean
...
poc-token-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b016a1f0d0 | ||
|
|
c009055693 | ||
|
|
14181c909c | ||
|
|
a05dc3823d | ||
|
|
7d19bdf085 | ||
|
|
a1b048f2ad | ||
|
|
0bd227196e | ||
|
|
eea7687ddf | ||
|
|
57d3ee5aac | ||
|
|
cfdfdecc14 | ||
|
|
ac995bae6d | ||
|
|
41a5509ce0 | ||
|
|
db5e26db94 | ||
|
|
fe975fb834 | ||
|
|
e368d2995b | ||
|
|
a3241d8376 | ||
|
|
6dfc5772ba | ||
|
|
f70925178c | ||
|
|
9554934b92 | ||
|
|
7fdb824a37 | ||
|
|
412407adc0 | ||
|
|
e0874d7de7 | ||
|
|
8df1536cbb | ||
|
|
fcbacc62ec | ||
|
|
ee2ae45653 | ||
|
|
6f2f0f9ae4 | ||
|
|
c37ebc6fb3 | ||
|
|
23abb5743c | ||
|
|
b87aa0bc15 | ||
|
|
f1a65d732d | ||
|
|
a3c0ea3e71 | ||
|
|
abaf061c2a | ||
|
|
e531fb54b1 | ||
|
|
5fcfed5b16 | ||
|
|
5f43449f67 | ||
|
|
6796601aa6 | ||
|
|
1fc25c301b | ||
|
|
08ae281b2d | ||
|
|
bd47f44c63 | ||
|
|
381260911b | ||
|
|
38db42e7d6 | ||
|
|
5d606d909d | ||
|
|
d689718b50 | ||
|
|
54a73c6649 | ||
|
|
418377842e | ||
|
|
15ef56e03d | ||
|
|
917035f8e8 | ||
|
|
963e3f5457 | ||
|
|
e20b969188 | ||
|
|
1c7059ee67 | ||
|
|
22a3365658 | ||
|
|
08ab1e3478 | ||
|
|
ebb1f4007d | ||
|
|
acb53ece93 | ||
|
|
e020950cfd | ||
|
|
9dba262a20 | ||
|
|
5bcdf36377 | ||
|
|
1ffe8deb10 | ||
|
|
d069145bd1 | ||
|
|
f3493ee042 | ||
|
|
bf48044e5c | ||
|
|
fb4cc37a4a | ||
|
|
55b8d89a79 | ||
|
|
6968a32a5a | ||
|
|
cfe6753349 | ||
|
|
5ae15b3af3 | ||
|
|
b79adb706c | ||
|
|
f22497d5da | ||
|
|
95d672c9df | ||
|
|
7d08a609e6 | ||
|
|
eea6120cd0 | ||
|
|
0cb02bd906 | ||
|
|
08d3867f41 | ||
|
|
b16d63643c | ||
|
|
940d01bdea | ||
|
|
ba9158d159 | ||
|
|
ca9a7e11ef | ||
|
|
a803f47685 | ||
|
|
79fed32f01 | ||
|
|
6b00bb0a66 | ||
|
|
e2adef1eea | ||
|
|
9e5fa11792 | ||
|
|
1ff75acb31 | ||
|
|
1754160686 | ||
|
|
423f6266fb | ||
|
|
16d1b4a14a | ||
|
|
7c14056faf | ||
|
|
62e37dc2e2 | ||
|
|
6a08695ee8 | ||
|
|
9a67a8e427 | ||
|
|
73aa0785ba | ||
|
|
53c1016a8e | ||
|
|
fd442138e6 | ||
|
|
be5f30225a | ||
|
|
7467e9fb8c | ||
|
|
2390c2e46e | ||
|
|
778c223176 | ||
|
|
36cd0dd85c | ||
|
|
09a1d5a02d | ||
|
|
7c996ac9b5 | ||
|
|
cf9fd5d960 | ||
|
|
1c5ab7cb8f | ||
|
|
aaad3b25a7 | ||
|
|
9904235a2f | ||
|
|
780e9f57a5 | ||
|
|
a8db73285b | ||
|
|
3b43c00d12 | ||
|
|
2f390e1794 | ||
|
|
3630ebb3ae | ||
|
|
260c46df04 | ||
|
|
7f11e3205d | ||
|
|
1c8f92a96f | ||
|
|
7b6294b624 | ||
|
|
156d0b1fef | ||
|
|
2cf00dba58 | ||
|
|
d2a7f3ae36 | ||
|
|
6a64d4e4dd | ||
|
|
51e63c246b | ||
|
|
99e6b1eda4 | ||
|
|
dc26a5a436 | ||
|
|
3883b2fb41 | ||
|
|
ed58659a01 | ||
|
|
5190923c70 | ||
|
|
7c647dd160 | ||
|
|
07e59b2708 | ||
|
|
0a3a9f977d | ||
|
|
2f263bf7e6 | ||
|
|
f65f4fc280 | ||
|
|
adbd7ab4c3 | ||
|
|
0419834482 | ||
|
|
f797d2d9cb | ||
|
|
5ae7efe8f7 | ||
|
|
d6e35bd0fe | ||
|
|
0e00f1c8f7 | ||
|
|
4433f44a12 | ||
|
|
7504e718d7 | ||
|
|
9b0387e7ee | ||
|
|
5ccce1ab3f | ||
|
|
e366fe340e | ||
|
|
b01809f8e3 | ||
|
|
790ef39187 | ||
|
|
3af16cf333 | ||
|
|
d09c69f303 | ||
|
|
096d4ac529 | ||
|
|
8fafde614a | ||
|
|
694ae13418 | ||
|
|
b5b7dd4f53 | ||
|
|
476785b122 | ||
|
|
907677f835 | ||
|
|
7d844b9410 | ||
|
|
eeabc64a73 | ||
|
|
5da2b0fdcc | ||
|
|
a0005a604e | ||
|
|
a89bb807a6 | ||
|
|
28f3354ffa | ||
|
|
562923c600 | ||
|
|
0dd0c67b3b | ||
|
|
ca33849f31 | ||
|
|
18cd0f1480 | ||
|
|
b02982f6b1 | ||
|
|
4d89ae27ef | ||
|
|
733ea77c5c | ||
|
|
92f72bfce6 | ||
|
|
bffb25bea7 | ||
|
|
3af4543e80 | ||
|
|
146774860b | ||
|
|
5243481316 | ||
|
|
76a39c1dcb | ||
|
|
02ce918114 | ||
|
|
30cfc22cb6 | ||
|
|
3168afbfcb | ||
|
|
a73ee47557 | ||
|
|
fa6ff005f2 | ||
|
|
095379fa60 | ||
|
|
30572fe1b8 | ||
|
|
3a6f364b03 | ||
|
|
5345d716ee | ||
|
|
f882c36e0a | ||
|
|
e95cfa1a00 | ||
|
|
0d480071b6 | ||
|
|
8e0b7b6c25 | ||
|
|
f204da0d68 | ||
|
|
7d74904d62 | ||
|
|
760ac5e07d | ||
|
|
4352228797 | ||
|
|
74c770609c | ||
|
|
f4ca36ed7e | ||
|
|
c86da92fc6 | ||
|
|
3f0c577456 | ||
|
|
717da8c7b7 | ||
|
|
a0a61d4f47 | ||
|
|
5b1fced872 | ||
|
|
c98dcf5ef9 | ||
|
|
57cb6bfccb | ||
|
|
95bf97dc3c | ||
|
|
3d116c9d33 | ||
|
|
a9ce9f8d5a | ||
|
|
10b981a855 | ||
|
|
7700b4333d | ||
|
|
7d0131111e | ||
|
|
1daea35e4b | ||
|
|
f97544af0d | ||
|
|
231e80cc15 | ||
|
|
a4c1362bff | ||
|
|
b611d4a751 | ||
|
|
2c9decfa55 | ||
|
|
3c5ac17e2f | ||
|
|
ae42bbb898 | ||
|
|
b86722394b | ||
|
|
a103f69767 | ||
|
|
73fbb3fc62 | ||
|
|
7b3523e25e | ||
|
|
6e4e1386e7 | ||
|
|
671e9af6eb | ||
|
|
50f42caf94 | ||
|
|
b7eeefc102 | ||
|
|
8dd22f3a4f | ||
|
|
4b89427447 | ||
|
|
b71e2860cf | ||
|
|
160b27bc60 | ||
|
|
c084386b88 | ||
|
|
6889047350 | ||
|
|
245bbb4acf | ||
|
|
2b2fc02d83 | ||
|
|
703ef29199 | ||
|
|
b0b60b938a | ||
|
|
e3a026bf1c | ||
|
|
94503465ee | ||
|
|
8d959b0abc | ||
|
|
1d8390b935 | ||
|
|
2851e38a1f | ||
|
|
51261fe7a9 | ||
|
|
304321d019 | ||
|
|
f8c3295645 | ||
|
|
183619d1e1 | ||
|
|
3b832d1f21 | ||
|
|
fcb849698f | ||
|
|
7527e0ebdb | ||
|
|
ed5f98da5b | ||
|
|
12b38e25da | ||
|
|
626e892e3b |
14
.github/ISSUE_TEMPLATE/config.yml
vendored
14
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,14 +0,0 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Community Support
|
||||
url: https://forum.netbird.io/
|
||||
about: Community support forum
|
||||
- name: Cloud Support
|
||||
url: https://docs.netbird.io/help/report-bug-issues
|
||||
about: Contact us for support
|
||||
- name: Client/Connection Troubleshooting
|
||||
url: https://docs.netbird.io/help/troubleshooting-client
|
||||
about: See our client troubleshooting guide for help addressing common issues
|
||||
- name: Self-host Troubleshooting
|
||||
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||
about: See our self-host troubleshooting guide for help addressing common issues
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
|
||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -97,16 +97,6 @@ jobs:
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
- name: Build combined
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build combined 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -154,7 +144,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -214,7 +204,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -409,19 +399,12 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
@@ -504,18 +487,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
@@ -596,18 +576,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
|
||||
9
.github/workflows/golang-test-windows.yml
vendored
9
.github/workflows/golang-test-windows.yml
vendored
@@ -63,15 +63,10 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- name: Generate test script
|
||||
run: |
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
51
.github/workflows/pr-title-check.yml
vendored
51
.github/workflows/pr-title-check.yml
vendored
@@ -1,51 +0,0 @@
|
||||
name: PR Title Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
check-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const allowedTags = [
|
||||
'management',
|
||||
'client',
|
||||
'signal',
|
||||
'proxy',
|
||||
'relay',
|
||||
'misc',
|
||||
'infrastructure',
|
||||
'self-hosted',
|
||||
'doc',
|
||||
];
|
||||
|
||||
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||
const match = title.match(pattern);
|
||||
|
||||
if (!match) {
|
||||
core.setFailed(
|
||||
`PR title must start with a tag in brackets.\n` +
|
||||
`Example: [client] fix something\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||
|
||||
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||
if (invalid.length > 0) {
|
||||
core.setFailed(
|
||||
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||
88
.github/workflows/release.yml
vendored
88
.github/workflows/release.yml
vendored
@@ -10,7 +10,7 @@ on:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -169,14 +169,6 @@ jobs:
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
@@ -184,7 +176,6 @@ jobs:
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@@ -194,55 +185,6 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||
rpm --import /tmp/rpm-pub.key
|
||||
echo "=== Verifying RPM signatures ==="
|
||||
for rpm_file in /dist/*amd64*.rpm; do
|
||||
[ -f "$rpm_file" ] || continue
|
||||
echo "--- $(basename $rpm_file) ---"
|
||||
rpm -K "$rpm_file"
|
||||
done
|
||||
'
|
||||
- name: Clean up GPG key
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: Tag and push images (amd64 only)
|
||||
if: |
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||
run: |
|
||||
resolve_tags() {
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
echo "pr-${{ github.event.pull_request.number }}"
|
||||
else
|
||||
echo "main sha-$(git rev-parse --short HEAD)"
|
||||
fi
|
||||
}
|
||||
|
||||
tag_and_push() {
|
||||
local src="$1" img_name tag dst
|
||||
img_name="${src%%:*}"
|
||||
for tag in $(resolve_tags); do
|
||||
dst="${img_name}:${tag}"
|
||||
echo "Tagging ${src} -> ${dst}"
|
||||
docker tag "$src" "$dst"
|
||||
docker push "$dst"
|
||||
done
|
||||
}
|
||||
|
||||
export -f tag_and_push resolve_tags
|
||||
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
tag_and_push "$SRC"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -309,14 +251,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||
|
||||
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||
run: |
|
||||
cd /tmp
|
||||
@@ -341,24 +275,6 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||
rpm --import /tmp/rpm-pub.key
|
||||
echo "=== Verifying RPM signatures ==="
|
||||
for rpm_file in /dist/*.rpm; do
|
||||
[ -f "$rpm_file" ] || continue
|
||||
echo "--- $(basename $rpm_file) ---"
|
||||
rpm -K "$rpm_file"
|
||||
done
|
||||
'
|
||||
- name: Clean up GPG key
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
126
.goreleaser.yaml
126
.goreleaser.yaml
@@ -140,40 +140,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-proxy
|
||||
dir: proxy/cmd/proxy
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-proxy
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-idp-migrate
|
||||
dir: tools/idp-migrate
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-idp-migrate
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -186,22 +152,18 @@ archives:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
- id: netbird-idp-migrate
|
||||
builds:
|
||||
- netbird-idp-migrate
|
||||
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_deb
|
||||
id: netbird-deb
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
- netbird
|
||||
formats:
|
||||
- deb
|
||||
|
||||
scripts:
|
||||
postinstall: "release_files/post_install.sh"
|
||||
preremove: "release_files/pre_remove.sh"
|
||||
@@ -209,19 +171,16 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
id: netbird_rpm
|
||||
id: netbird-rpm
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
- netbird
|
||||
formats:
|
||||
- rpm
|
||||
|
||||
scripts:
|
||||
postinstall: "release_files/post_install.sh"
|
||||
preremove: "release_files/pre_remove.sh"
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
@@ -630,55 +589,6 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -859,30 +769,6 @@ docker_manifests:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
@@ -903,7 +789,7 @@ brews:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird_deb
|
||||
- netbird-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -911,7 +797,7 @@ uploads:
|
||||
|
||||
- name: yum
|
||||
ids:
|
||||
- netbird_rpm
|
||||
- netbird-rpm
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -61,7 +61,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird_ui_deb
|
||||
id: netbird-ui-deb
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
- netbird-ui
|
||||
@@ -80,7 +80,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird_ui_rpm
|
||||
id: netbird-ui-rpm
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
- netbird-ui
|
||||
@@ -95,14 +95,11 @@ nfpms:
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird_ui_deb
|
||||
- netbird-ui-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -110,7 +107,7 @@ uploads:
|
||||
|
||||
- name: yum
|
||||
ids:
|
||||
- netbird_ui_rpm
|
||||
- netbird-ui-rpm
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
## Contributor License Agreement
|
||||
|
||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
|
||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
@@ -126,7 +126,6 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.3
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
@@ -17,7 +17,8 @@ ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ ENV \
|
||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||
NB_DISABLE_DNS="true" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
pi := PeerInfo{
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
int(p.ConnStatus),
|
||||
p.ConnStatus.String(),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
|
||||
@@ -2,20 +2,11 @@
|
||||
|
||||
package android
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
// Connection status constants exported via gomobile.
|
||||
const (
|
||||
ConnStatusIdle = int(peer.StatusIdle)
|
||||
ConnStatusConnecting = int(peer.StatusConnecting)
|
||||
ConnStatusConnected = int(peer.StatusConnected)
|
||||
)
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus int
|
||||
ConnStatus string // Todo replace to enum
|
||||
Routes PeerRoutes
|
||||
}
|
||||
|
||||
|
||||
@@ -181,11 +181,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
@@ -199,13 +198,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -213,15 +209,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
@@ -267,28 +261,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird down")
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
var (
|
||||
exposePin string
|
||||
exposePassword string
|
||||
exposeUserGroups []string
|
||||
exposeDomain string
|
||||
exposeNamePrefix string
|
||||
exposeProtocol string
|
||||
exposeExternalPort uint16
|
||||
)
|
||||
|
||||
var exposeCmd = &cobra.Command{
|
||||
Use: "expose <port>",
|
||||
Short: "Expose a local port via the NetBird reverse proxy",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: ` netbird expose --with-password safe-pass 8080
|
||||
netbird expose --protocol tcp 5432
|
||||
netbird expose --protocol tcp --with-external-port 5433 5432
|
||||
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
|
||||
RunE: exposeFn,
|
||||
}
|
||||
|
||||
func init() {
|
||||
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
|
||||
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
|
||||
}
|
||||
|
||||
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
|
||||
func isClusterProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
|
||||
// where domain display doesn't apply. TLS uses SNI so it has a domain.
|
||||
func isPortBasedProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractPort returns the port portion of a URL like "tcp://host:12345", or
|
||||
// falls back to the given default formatted as a string.
|
||||
func extractPort(serviceURL string, fallback uint16) string {
|
||||
u := serviceURL
|
||||
if idx := strings.Index(u, "://"); idx != -1 {
|
||||
u = u[idx+3:]
|
||||
}
|
||||
if i := strings.LastIndex(u, ":"); i != -1 {
|
||||
if p := u[i+1:]; p != "" {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return strconv.FormatUint(uint64(fallback), 10)
|
||||
}
|
||||
|
||||
// resolveExternalPort returns the effective external port, defaulting to the target port.
|
||||
func resolveExternalPort(targetPort uint64) uint16 {
|
||||
if exposeExternalPort != 0 {
|
||||
return exposeExternalPort
|
||||
}
|
||||
return uint16(targetPort)
|
||||
}
|
||||
|
||||
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||
}
|
||||
if port == 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||
}
|
||||
|
||||
if !isProtocolValid(exposeProtocol) {
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
}
|
||||
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
|
||||
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
} else if cmd.Flags().Changed("with-external-port") {
|
||||
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
|
||||
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||
return 0, fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||
return 0, fmt.Errorf("user groups cannot be empty")
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func isProtocolValid(exposeProtocol string) bool {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http", "https", "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = false
|
||||
|
||||
port, err := validateExposeFlags(cmd, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = true
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close daemon connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
protocol, err := toExposeProtocol(exposeProtocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &proto.ExposeServiceRequest{
|
||||
Port: uint32(port),
|
||||
Protocol: protocol,
|
||||
Pin: exposePin,
|
||||
Password: exposePassword,
|
||||
UserGroups: exposeUserGroups,
|
||||
Domain: exposeDomain,
|
||||
NamePrefix: exposeNamePrefix,
|
||||
}
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
req.ListenPort = uint32(resolveExternalPort(port))
|
||||
}
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return waitForExposeEvents(cmd, ctx, stream)
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
switch p {
|
||||
case expose.ProtocolHTTP:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case expose.ProtocolHTTPS:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
case expose.ProtocolTCP:
|
||||
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||
case expose.ProtocolUDP:
|
||||
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||
case expose.ProtocolTLS:
|
||||
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unhandled protocol type: %d", p)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||
}
|
||||
printExposeReady(cmd, ready.Ready, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
|
||||
cmd.Println("Service exposed successfully!")
|
||||
cmd.Printf(" Name: %s\n", r.ServiceName)
|
||||
if r.ServiceUrl != "" {
|
||||
cmd.Printf(" URL: %s\n", r.ServiceUrl)
|
||||
}
|
||||
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
|
||||
cmd.Printf(" Domain: %s\n", r.Domain)
|
||||
}
|
||||
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||
cmd.Printf(" Internal: %d\n", port)
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
|
||||
}
|
||||
if r.PortAutoAssigned && exposeExternalPort != 0 {
|
||||
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
|
||||
}
|
||||
cmd.Println()
|
||||
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||
}
|
||||
|
||||
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.Println("\nService stopped.")
|
||||
//nolint:nilerr
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -81,15 +80,6 @@ var (
|
||||
Short: "",
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(cmd.Root())
|
||||
|
||||
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||
if !isServiceCmd(cmd) {
|
||||
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,7 +144,6 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||
func isServiceCmd(cmd *cobra.Command) bool {
|
||||
for c := cmd; c != nil; c = c.Parent() {
|
||||
if c.Name() == "service" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func init() {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -119,10 +119,6 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -140,10 +136,6 @@ var installCmd = &cobra.Command{
|
||||
return fmt.Errorf("install service: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been installed")
|
||||
return nil
|
||||
},
|
||||
@@ -195,10 +187,6 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
wasRunning, err := isServiceRunning()
|
||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||
return fmt.Errorf("check service status: %w", err)
|
||||
@@ -234,10 +222,6 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return fmt.Errorf("install service with new config: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
if wasRunning {
|
||||
cmd.Println("Starting NetBird service...")
|
||||
if err := s.Start(); err != nil {
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const serviceParamsFile = "service.json"
|
||||
|
||||
// serviceParams holds install-time service parameters that persist across
|
||||
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||
type serviceParams struct {
|
||||
LogLevel string `json:"log_level"`
|
||||
DaemonAddr string `json:"daemon_addr"`
|
||||
ManagementURL string `json:"management_url,omitempty"`
|
||||
ConfigPath string `json:"config_path,omitempty"`
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
// serviceParamsPath returns the path to the service params file.
|
||||
func serviceParamsPath() string {
|
||||
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
}
|
||||
|
||||
// loadServiceParams reads saved service parameters from disk.
|
||||
// Returns nil with no error if the file does not exist.
|
||||
func loadServiceParams() (*serviceParams, error) {
|
||||
path := serviceParamsPath()
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
var params serviceParams
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
// saveServiceParams writes current service parameters to disk atomically
|
||||
// with restricted permissions.
|
||||
func saveServiceParams(params *serviceParams) error {
|
||||
path := serviceParamsPath()
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||
return fmt.Errorf("save service params: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentServiceParams captures the current state of all package-level
|
||||
// variables into a serviceParams struct.
|
||||
func currentServiceParams() *serviceParams {
|
||||
params := &serviceParams{
|
||||
LogLevel: logLevel,
|
||||
DaemonAddr: daemonAddr,
|
||||
ManagementURL: managementURL,
|
||||
ConfigPath: configPath,
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err == nil && len(parsed) > 0 {
|
||||
params.ServiceEnvVars = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||
// to any flags that were not explicitly set.
|
||||
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
applyServiceParams(cmd, params)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyServiceParams merges saved parameters into package-level variables
|
||||
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||
// Flags that were Changed() are left untouched.
|
||||
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||
// != "" guard so that an older service.json missing the field doesn't
|
||||
// clobber the default with an empty string.
|
||||
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||
logLevel = params.LogLevel
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||
daemonAddr = params.DaemonAddr
|
||||
}
|
||||
|
||||
// For optional fields where empty means "use default", always apply so
|
||||
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||
managementURL = params.ManagementURL
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("config") {
|
||||
configPath = params.ConfigPath
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||
logFiles = params.LogFiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||
profilesDisabled = params.DisableProfiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
// applyServiceEnvParams merges saved service environment variables.
|
||||
// If --service-env was explicitly set, explicit values win on key conflict
|
||||
// but saved keys not in the explicit set are carried over.
|
||||
// If --service-env was not set, saved env vars are used entirely.
|
||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if len(params.ServiceEnvVars) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !cmd.Flags().Changed("service-env") {
|
||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||
return
|
||||
}
|
||||
|
||||
// Explicit env vars were provided: merge saved values underneath.
|
||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||
maps.Copy(merged, params.ServiceEnvVars)
|
||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||
serviceEnvVars = envMapToSlice(merged)
|
||||
}
|
||||
|
||||
var resetParamsCmd = &cobra.Command{
|
||||
Use: "reset-params",
|
||||
Short: "Remove saved service install parameters",
|
||||
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
path := serviceParamsPath()
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
cmd.Println("No saved service parameters found")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("remove service params: %w", err)
|
||||
}
|
||||
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||
func envMapToSlice(m map[string]string) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
s = append(s, k+"="+v)
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,523 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
)
|
||||
|
||||
func TestServiceParamsPath(t *testing.T) {
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
|
||||
configs.StateDir = "/var/lib/netbird"
|
||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||
|
||||
configs.StateDir = "/custom/state"
|
||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||
}
|
||||
|
||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
ManagementURL: "https://my.server.com",
|
||||
ConfigPath: "/etc/netbird/config.json",
|
||||
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: false,
|
||||
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||
}
|
||||
|
||||
err := saveServiceParams(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and is valid JSON.
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, json.Valid(data))
|
||||
|
||||
loaded, err := loadServiceParams()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestCurrentServiceParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
logLevel = "trace"
|
||||
daemonAddr = "tcp://127.0.0.1:9999"
|
||||
managementURL = "https://mgmt.example.com"
|
||||
configPath = "/tmp/test-config.json"
|
||||
logFiles = []string{"/tmp/test.log"}
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||
|
||||
params := currentServiceParams()
|
||||
|
||||
assert.Equal(t, "trace", params.LogLevel)
|
||||
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||
assert.True(t, params.DisableProfiles)
|
||||
assert.True(t, params.DisableUpdateSettings)
|
||||
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
// Reset all flags to defaults.
|
||||
logLevel = "info"
|
||||
daemonAddr = "unix:///var/run/netbird.sock"
|
||||
managementURL = ""
|
||||
configPath = "/etc/netbird/config.json"
|
||||
logFiles = []string{"/var/log/netbird/client.log"}
|
||||
profilesDisabled = false
|
||||
updateSettingsDisabled = false
|
||||
serviceEnvVars = nil
|
||||
|
||||
// Reset Changed state on all relevant flags.
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Simulate user explicitly setting --log-level via CLI.
|
||||
logLevel = "warn"
|
||||
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||
|
||||
saved := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||
ManagementURL: "https://saved.example.com",
|
||||
ConfigPath: "/saved/config.json",
|
||||
LogFiles: []string{"/saved/client.log"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: true,
|
||||
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||
assert.Equal(t, "warn", logLevel)
|
||||
|
||||
// All other fields were not Changed, so they should use saved values.
|
||||
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||
assert.Equal(t, "/saved/config.json", configPath)
|
||||
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||
assert.True(t, profilesDisabled)
|
||||
assert.True(t, updateSettingsDisabled)
|
||||
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
t.Cleanup(func() {
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
})
|
||||
|
||||
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
|
||||
// Reset Changed state so flags appear unset.
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Saved params have both as false.
|
||||
saved := &serviceParams{
|
||||
DisableProfiles: false,
|
||||
DisableUpdateSettings: false,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||
origManagementURL := managementURL
|
||||
t.Cleanup(func() { managementURL = origManagementURL })
|
||||
|
||||
managementURL = "https://leftover.example.com"
|
||||
|
||||
// Simulate saved params where management URL was explicitly cleared.
|
||||
saved := &serviceParams{
|
||||
LogLevel: "info",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
t.Cleanup(func() { logLevel = origLogLevel })
|
||||
|
||||
logLevel = "info"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
// Should be a no-op.
|
||||
applyServiceParams(cmd, nil)
|
||||
assert.Equal(t, "info", logLevel)
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Set up a command with --service-env marked as Changed.
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||
|
||||
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{
|
||||
"SAVED": "val",
|
||||
"OVERLAP": "saved",
|
||||
},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
// Parse result for easier assertion.
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||
assert.Equal(t, "val", result["SAVED"])
|
||||
// Explicit wins on conflict.
|
||||
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
serviceEnvVars = nil
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||
}
|
||||
|
||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||
// added to serviceParams but not wired into these functions, this test fails.
|
||||
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect all JSON field names from the serviceParams struct.
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||
|
||||
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||
for k, v := range applyEnvFields {
|
||||
applyFields[k] = v
|
||||
}
|
||||
|
||||
for _, field := range structFields {
|
||||
assert.Contains(t, currentFields, field,
|
||||
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||
assert.Contains(t, applyFields, field,
|
||||
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields)
|
||||
|
||||
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||
fieldsNotInArgs := map[string]bool{
|
||||
"ServiceEnvVars": true,
|
||||
}
|
||||
|
||||
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||
|
||||
// Forward: every struct field must appear in buildServiceArguments.
|
||||
for _, field := range structFields {
|
||||
if fieldsNotInArgs[field] {
|
||||
continue
|
||||
}
|
||||
globalVar := fieldToGlobalVar(field)
|
||||
assert.Contains(t, buildFields, globalVar,
|
||||
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||
}
|
||||
|
||||
// Reverse: every service-related global used in buildServiceArguments must
|
||||
// have a corresponding serviceParams field. This catches a developer adding
|
||||
// a new flag to buildServiceArguments without adding it to the struct.
|
||||
globalToField := make(map[string]string, len(structFields))
|
||||
for _, field := range structFields {
|
||||
globalToField[fieldToGlobalVar(field)] = field
|
||||
}
|
||||
// Identifiers in buildServiceArguments that are not service params
|
||||
// (builtins, boilerplate, loop variables).
|
||||
nonParamGlobals := map[string]bool{
|
||||
"args": true, "append": true, "string": true, "_": true,
|
||||
"logFile": true, // range variable over logFiles
|
||||
}
|
||||
for ref := range buildFields {
|
||||
if nonParamGlobals[ref] {
|
||||
continue
|
||||
}
|
||||
_, inStruct := globalToField[ref]
|
||||
assert.True(t, inStruct,
|
||||
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||
}
|
||||
}
|
||||
|
||||
// extractStructJSONFields returns field names from a named struct type.
|
||||
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||
t.Helper()
|
||||
var fields []string
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
ts, ok := n.(*ast.TypeSpec)
|
||||
if !ok || ts.Name.Name != structName {
|
||||
return true
|
||||
}
|
||||
st, ok := ts.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, f := range st.Fields.List {
|
||||
if len(f.Names) > 0 {
|
||||
fields = append(fields, f.Names[0].Name)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return fields
|
||||
}
|
||||
|
||||
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||
// named function, either as selector expressions (params.FieldName) or as
|
||||
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||
t.Helper()
|
||||
fieldSet := make(map[string]bool, len(fields))
|
||||
for _, f := range fields {
|
||||
fieldSet[f] = true
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
switch v := n.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
if fieldSet[v.Sel.Name] {
|
||||
found[v.Sel.Name] = true
|
||||
}
|
||||
case *ast.KeyValueExpr:
|
||||
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||
found[ident.Name] = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||
t.Helper()
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
refs := make(map[string]bool)
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
if ident, ok := n.(*ast.Ident); ok {
|
||||
refs[ident.Name] = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
return refs
|
||||
}
|
||||
|
||||
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if ok && fn.Name.Name == name {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||
// names used in buildServiceArguments and applyServiceParams.
|
||||
func fieldToGlobalVar(field string) string {
|
||||
m := map[string]string{
|
||||
"LogLevel": "logLevel",
|
||||
"DaemonAddr": "daemonAddr",
|
||||
"ManagementURL": "managementURL",
|
||||
"ConfigPath": "configPath",
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
if v, ok := m[field]; ok {
|
||||
return v
|
||||
}
|
||||
// Default: lowercase first letter.
|
||||
return strings.ToLower(field[:1]) + field[1:]
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice(t *testing.T) {
|
||||
m := map[string]string{"A": "1", "B": "2"}
|
||||
s := envMapToSlice(m)
|
||||
assert.Len(t, s, 2)
|
||||
assert.Contains(t, s, "A=1")
|
||||
assert.Contains(t, s, "B=2")
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||
s := envMapToSlice(map[string]string{})
|
||||
assert.Empty(t, s)
|
||||
}
|
||||
@@ -4,9 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,22 +13,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -97,34 +79,6 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -28,7 +28,6 @@ var (
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
connectionTypeFilter string
|
||||
checkFlag string
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -50,7 +49,6 @@ func init() {
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -58,10 +56,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
if checkFlag != "" {
|
||||
return runHealthCheck(cmd)
|
||||
}
|
||||
|
||||
err := parseFilters()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -74,17 +68,15 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx, true, false)
|
||||
resp, err := getStatus(ctx, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := resp.GetStatus()
|
||||
|
||||
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired)
|
||||
|
||||
if needsAuth && !jsonFlag && !yamlFlag {
|
||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired) {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
@@ -107,17 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
DaemonStatus: nbstatus.ParseDaemonStatus(status),
|
||||
StatusFilter: statusFilter,
|
||||
PrefixNamesFilter: prefixNamesFilter,
|
||||
PrefixNamesFilterMap: prefixNamesFilterMap,
|
||||
IPsFilter: ipsFilterMap,
|
||||
ConnectionTypeFilter: connectionTypeFilter,
|
||||
ProfileName: profName,
|
||||
})
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -139,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
@@ -149,7 +131,7 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -203,83 +185,6 @@ func enableDetailFlagWhenFilterFlag() {
|
||||
}
|
||||
}
|
||||
|
||||
func runHealthCheck(cmd *cobra.Command) error {
|
||||
check := strings.ToLower(checkFlag)
|
||||
switch check {
|
||||
case "live", "ready", "startup":
|
||||
default:
|
||||
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
isStartup := check == "startup"
|
||||
resp, err := getStatus(ctx, isStartup, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch check {
|
||||
case "live":
|
||||
return nil
|
||||
case "ready":
|
||||
return checkReadiness(resp)
|
||||
case "startup":
|
||||
return checkStartup(resp)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkReadiness(resp *proto.StatusResponse) error {
|
||||
daemonStatus := internal.StatusType(resp.GetStatus())
|
||||
switch daemonStatus {
|
||||
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
|
||||
return nil
|
||||
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
|
||||
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
|
||||
default:
|
||||
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func checkStartup(resp *proto.StatusResponse) error {
|
||||
fullStatus := resp.GetFullStatus()
|
||||
if fullStatus == nil {
|
||||
return fmt.Errorf("startup check: no full status available")
|
||||
}
|
||||
|
||||
if !fullStatus.GetManagementState().GetConnected() {
|
||||
return fmt.Errorf("startup check: management not connected")
|
||||
}
|
||||
|
||||
if !fullStatus.GetSignalState().GetConnected() {
|
||||
return fmt.Errorf("startup check: signal not connected")
|
||||
}
|
||||
|
||||
var relayCount, relaysConnected int
|
||||
for _, r := range fullStatus.GetRelays() {
|
||||
uri := r.GetURI()
|
||||
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
|
||||
continue
|
||||
}
|
||||
relayCount++
|
||||
if r.GetAvailable() {
|
||||
relaysConnected++
|
||||
}
|
||||
}
|
||||
|
||||
if relayCount > 0 && relaysConnected == 0 {
|
||||
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
|
||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -33,14 +31,14 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -83,14 +81,6 @@ type Options struct {
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the WireGuard interface.
|
||||
// Valid values are in the range 576..8192 bytes.
|
||||
// If non-nil, this value overrides any value stored in the config file.
|
||||
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -122,12 +112,6 @@ func New(opts Options) (*Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.MTU != nil {
|
||||
if err := iface.ValidateMTU(*opts.MTU); err != nil {
|
||||
return nil, fmt.Errorf("invalid MTU: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
@@ -156,14 +140,9 @@ func New(opts Options) (*Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
var parsedLabels domain.List
|
||||
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
|
||||
return nil, fmt.Errorf("invalid dns labels: %w", err)
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *profilemanager.Config
|
||||
var err error
|
||||
input := profilemanager.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
@@ -172,8 +151,6 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -225,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
@@ -375,32 +352,6 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
|
||||
// It returns an ExposeSession. Call Wait on the session to keep it alive.
|
||||
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return nil, fmt.Errorf("expose manager not available")
|
||||
}
|
||||
|
||||
resp, err := mgr.Expose(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expose: %w", err)
|
||||
}
|
||||
|
||||
return &ExposeSession{
|
||||
Domain: resp.Domain,
|
||||
ServiceName: resp.ServiceName,
|
||||
ServiceURL: resp.ServiceURL,
|
||||
mgr: mgr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
)
|
||||
|
||||
const (
|
||||
// ExposeProtocolHTTP exposes the service as HTTP.
|
||||
ExposeProtocolHTTP = expose.ProtocolHTTP
|
||||
// ExposeProtocolHTTPS exposes the service as HTTPS.
|
||||
ExposeProtocolHTTPS = expose.ProtocolHTTPS
|
||||
// ExposeProtocolTCP exposes the service as TCP.
|
||||
ExposeProtocolTCP = expose.ProtocolTCP
|
||||
// ExposeProtocolUDP exposes the service as UDP.
|
||||
ExposeProtocolUDP = expose.ProtocolUDP
|
||||
// ExposeProtocolTLS exposes the service as TLS.
|
||||
ExposeProtocolTLS = expose.ProtocolTLS
|
||||
)
|
||||
|
||||
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
|
||||
type ExposeRequest = expose.Request
|
||||
|
||||
// ExposeProtocolType represents the protocol used for exposing a service.
|
||||
type ExposeProtocolType = expose.ProtocolType
|
||||
|
||||
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
|
||||
type ExposeSession struct {
|
||||
Domain string
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
|
||||
mgr *expose.Manager
|
||||
}
|
||||
|
||||
// Wait blocks while keeping the expose session alive.
|
||||
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
|
||||
func (s *ExposeSession) Wait(ctx context.Context) error {
|
||||
if s == nil || s.mgr == nil {
|
||||
return errors.New("expose session is not initialized")
|
||||
}
|
||||
return s.mgr.KeepAlive(ctx, s.Domain)
|
||||
}
|
||||
@@ -23,10 +23,9 @@ type Manager struct {
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
rawSupported bool
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -85,7 +84,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
@@ -286,22 +285,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
@@ -335,10 +318,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !m.rawSupported {
|
||||
return fmt.Errorf("raw table not available")
|
||||
}
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
@@ -396,16 +375,12 @@ func (m *Manager) initNoTrackChain() error {
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
m.rawSupported = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
if !m.rawSupported {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
@@ -426,7 +401,6 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
m.rawSupported = false
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -44,7 +43,6 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -389,14 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -406,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -981,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -169,14 +169,6 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
@@ -346,22 +346,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
|
||||
@@ -36,7 +36,6 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -1854,130 +1853,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -140,17 +140,6 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -605,8 +594,6 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -726,9 +713,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -911,21 +895,38 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1277,6 +1278,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1335,30 +1342,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
}
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
}
|
||||
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -187,52 +186,81 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -502,12 +530,39 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -537,7 +592,8 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
manager.SetUDPPacketHook(
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -545,6 +601,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
@@ -144,8 +144,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
@@ -421,7 +421,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -467,22 +466,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,7 +18,9 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,17 +399,21 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
// for js, the outer websocket layer takes care of tls
|
||||
if tlsEnabled && runtime.GOOS != "js" {
|
||||
@@ -46,7 +46,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
opts := []grpc.DialOption{
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
@@ -54,10 +56,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
}
|
||||
opts = append(opts, extraOpts...)
|
||||
|
||||
conn, err := grpc.DialContext(connCtx, addr, opts...)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
@@ -5,18 +5,20 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||
if err != nil {
|
||||
_ = uapiSock.Close()
|
||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
|
||||
@@ -15,17 +15,14 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
|
||||
@@ -34,28 +34,18 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -85,3 +75,17 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
87
client/iface/mocks/iface/mocks/filter.go
Normal file
87
client/iface/mocks/iface/mocks/filter.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,7 +201,13 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -215,7 +221,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
@@ -240,7 +246,13 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -254,7 +266,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
@@ -280,16 +292,28 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -298,7 +322,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -23,13 +23,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -44,19 +43,14 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
doInitialAutoUpdate bool
|
||||
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
clientMetrics *metrics.ClientMetrics
|
||||
updateManager *updater.Manager
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
|
||||
persistSyncResponse bool
|
||||
}
|
||||
@@ -65,24 +59,19 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
@@ -111,7 +100,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -121,7 +109,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -144,34 +131,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
}()
|
||||
|
||||
// Stop metrics push on exit
|
||||
defer func() {
|
||||
if c.clientMetrics != nil {
|
||||
c.clientMetrics.StopPush()
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
nbnet.Init()
|
||||
|
||||
// Initialize metrics once at startup (always active for debug bundles)
|
||||
if c.clientMetrics == nil {
|
||||
agentInfo := metrics.AgentInfo{
|
||||
DeploymentType: metrics.DeploymentTypeUnknown,
|
||||
Version: version.NetbirdVersion(),
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
}
|
||||
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
|
||||
log.Debugf("initialized client metrics")
|
||||
|
||||
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
|
||||
if metrics.IsMetricsPushEnabled() {
|
||||
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
|
||||
}
|
||||
}
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -224,13 +187,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
stateManager := statemanager.New(path)
|
||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
if c.updateManager != nil {
|
||||
c.updateManager.CheckUpdateSuccess(c.ctx)
|
||||
}
|
||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||
if err == nil {
|
||||
updateManager.CheckUpdateSuccess(c.ctx)
|
||||
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
@@ -258,16 +222,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||
mgmClient.SetConnStateListener(mgmNotifier)
|
||||
|
||||
// Update metrics with actual deployment type after connection
|
||||
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
|
||||
agentInfo := metrics.AgentInfo{
|
||||
DeploymentType: deploymentType,
|
||||
Version: version.NetbirdVersion(),
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
}
|
||||
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
defer func() {
|
||||
if err = mgmClient.Close(); err != nil {
|
||||
@@ -276,10 +230,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}()
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginStarted := time.Now()
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
if err != nil {
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
@@ -288,7 +240,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
return wrapErr(err)
|
||||
}
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
||||
c.statusRecorder.MarkManagementConnected()
|
||||
|
||||
localPeerState := peer.LocalPeerState{
|
||||
@@ -357,16 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmClient,
|
||||
RelayManager: relayManager,
|
||||
StatusRecorder: c.statusRecorder,
|
||||
Checks: checks,
|
||||
StateManager: stateManager,
|
||||
UpdateManager: c.updateManager,
|
||||
ClientMetrics: c.clientMetrics,
|
||||
}, mobileDependency)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
@@ -376,15 +318,21 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||
if c.doInitialAutoUpdate {
|
||||
log.Infof("start engine by ui, run auto-update check")
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
c.doInitialAutoUpdate = false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -619,6 +567,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -637,7 +591,12 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||
// discovered — connections will fall back to relay. Applications that need P2P
|
||||
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||
// Android's ConnectivityManager to enumerate network interfaces.
|
||||
type noopIFaceDiscover struct{}
|
||||
|
||||
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||
// Relay connections still work; only P2P hole-punching is disabled.
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||
// Network change events are ignored since the embed client manages its own
|
||||
// reconnection logic via the engine's built-in retry mechanism.
|
||||
type noopNetworkChangeListener struct{}
|
||||
|
||||
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||
// No-op: embed.Client relies on the engine's internal reconnection
|
||||
// logic rather than OS-level network change notifications.
|
||||
}
|
||||
|
||||
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||
// network stack, not by OS-level interface configuration.
|
||||
}
|
||||
|
||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||
// DNS readiness notifications are not needed in netstack/embed mode
|
||||
// since system DNS is disabled and DNS resolution happens externally.
|
||||
type noopDnsReadyListener struct{}
|
||||
|
||||
func (noopDnsReadyListener) OnReady() {
|
||||
// No-op: embed.Client does not need DNS readiness notifications.
|
||||
// System DNS is disabled in netstack mode.
|
||||
}
|
||||
|
||||
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var scanDir = "/var/run/netbird"
|
||||
|
||||
// setScanDir overrides the scan directory (used by tests).
|
||||
func setScanDir(dir string) {
|
||||
scanDir = dir
|
||||
}
|
||||
|
||||
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||
// mismatch between the netbird@.service template (which places the socket under
|
||||
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
if !strings.HasPrefix(addr, "unix://") {
|
||||
return addr
|
||||
}
|
||||
|
||||
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||
if _, err := os.Stat(sockPath); err == nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(scanDir)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
var found []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(e.Name(), ".sock") {
|
||||
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
switch len(found) {
|
||||
case 1:
|
||||
resolved := "unix://" + found[0]
|
||||
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
return resolved
|
||||
case 0:
|
||||
return addr
|
||||
default:
|
||||
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||
return addr
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build windows || ios || android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
return addr
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createSockFile creates a regular file with a .sock extension.
|
||||
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||
func createSockFile(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
sock := filepath.Join(tmp, "netbird.sock")
|
||||
createSockFile(t, sock)
|
||||
|
||||
addr := "unix://" + sock
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Default socket does not exist
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
// Create a scan dir with one socket
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
instanceSock := filepath.Join(sd, "main.sock")
|
||||
createSockFile(t, instanceSock)
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
expected := "unix://" + instanceSock
|
||||
if got != expected {
|
||||
t.Errorf("expected %s, got %s", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||
addr := "tcp://127.0.0.1:41731"
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
@@ -25,13 +25,13 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -53,8 +53,6 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
block.prof: Block profiling information.
|
||||
@@ -221,11 +219,6 @@ const (
|
||||
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
||||
)
|
||||
|
||||
// MetricsExporter is an interface for exporting metrics
|
||||
type MetricsExporter interface {
|
||||
Export(w io.Writer) error
|
||||
}
|
||||
|
||||
type BundleGenerator struct {
|
||||
anonymizer *anonymize.Anonymizer
|
||||
|
||||
@@ -236,7 +229,6 @@ type BundleGenerator struct {
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
|
||||
anonymize bool
|
||||
includeSystemInfo bool
|
||||
@@ -258,7 +250,6 @@ type GeneratorDependencies struct {
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
ClientMetrics MetricsExporter
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -277,7 +268,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
@@ -361,14 +351,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addServiceParams(); err != nil {
|
||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addWgShow(); err != nil {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
@@ -436,10 +418,7 @@ func (g *BundleGenerator) addStatus() error {
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
||||
Anonymize: g.anonymize,
|
||||
ProfileName: profName,
|
||||
})
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||
statusOutput := overview.FullDetailSummary()
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
@@ -494,90 +473,6 @@ func (g *BundleGenerator) addConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serviceParamsFile = "service.json"
|
||||
serviceParamsBundle = "service_params.json"
|
||||
maskedValue = "***"
|
||||
envVarPrefix = "NB_"
|
||||
jsonKeyManagementURL = "management_url"
|
||||
jsonKeyServiceEnv = "service_env_vars"
|
||||
)
|
||||
|
||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||
|
||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||
func (g *BundleGenerator) addServiceParams() error {
|
||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read service params: %w", err)
|
||||
}
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return fmt.Errorf("parse service params: %w", err)
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||
}
|
||||
}
|
||||
|
||||
g.sanitizeServiceEnvVars(params)
|
||||
|
||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||
return fmt.Errorf("add service params to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(envVars))
|
||||
for k, v := range envVars {
|
||||
val, _ := v.(string)
|
||||
switch {
|
||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||
sanitized[k] = maskedValue
|
||||
case g.anonymize:
|
||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||
default:
|
||||
sanitized[k] = val
|
||||
}
|
||||
}
|
||||
params[jsonKeyServiceEnv] = sanitized
|
||||
}
|
||||
|
||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||
func isSensitiveEnvVar(key string) bool {
|
||||
lower := strings.ToLower(key)
|
||||
for _, s := range sensitiveEnvSubstrings {
|
||||
if strings.Contains(lower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
@@ -849,30 +744,6 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addMetrics() error {
|
||||
if g.clientMetrics == nil {
|
||||
log.Debugf("skipping metrics in debug bundle: no metrics collector")
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := g.clientMetrics.Export(&buf); err != nil {
|
||||
return fmt.Errorf("export metrics: %w", err)
|
||||
}
|
||||
|
||||
if buf.Len() == 0 {
|
||||
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
|
||||
return fmt.Errorf("add metrics file to zip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("added metrics to debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -14,7 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -425,226 +420,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
sensitive bool
|
||||
}{
|
||||
{"NB_SETUP_KEY", true},
|
||||
{"NB_API_TOKEN", true},
|
||||
{"NB_CLIENT_SECRET", true},
|
||||
{"NB_PASSWORD", true},
|
||||
{"NB_CREDENTIAL", true},
|
||||
{"NB_LOG_LEVEL", false},
|
||||
{"NB_MANAGEMENT_URL", false},
|
||||
{"NB_HOSTNAME", false},
|
||||
{"HOME", false},
|
||||
{"PATH", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anonymize bool
|
||||
input map[string]any
|
||||
check func(t *testing.T, params map[string]any)
|
||||
}{
|
||||
{
|
||||
name: "no env vars key",
|
||||
anonymize: false,
|
||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||
_, ok := params[jsonKeyServiceEnv]
|
||||
assert.False(t, ok, "service_env_vars should not be added")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sensitive NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safe NB vars anonymized when anonymize is true",
|
||||
anonymize: true,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
"NB_SETUP_KEY": "secret",
|
||||
"SOME_OTHER": "val",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||
|
||||
logVal := env["NB_LOG_LEVEL"].(string)
|
||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||
|
||||
// Sensitive and non-NB_ still masked
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
g := &BundleGenerator{
|
||||
anonymize: tt.anonymize,
|
||||
anonymizer: anonymizer,
|
||||
}
|
||||
g.sanitizeServiceEnvVars(tt.input)
|
||||
tt.check(t, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddServiceParams(t *testing.T) {
|
||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
}
|
||||
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = t.TempDir()
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
err := g.addServiceParams()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_LOG_LEVEL": "trace",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: true,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, zr.File, 1)
|
||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
mgmt := result[jsonKeyManagementURL].(string)
|
||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||
assert.NotEmpty(t, mgmt)
|
||||
|
||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||
})
|
||||
|
||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: false,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
|
||||
@@ -73,9 +73,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -198,14 +195,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
fields := log.Fields{
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
})
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -268,9 +261,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -24,7 +22,6 @@ import (
|
||||
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
@@ -38,14 +35,6 @@ const (
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
|
||||
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||
maxDomainsPerResolverEntry = 50
|
||||
|
||||
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||
maxDomainBytesPerResolverEntry = 1500
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
@@ -95,23 +84,28 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old match keys: %v", err)
|
||||
}
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old search keys: %v", err)
|
||||
}
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
@@ -155,7 +149,8 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
return s.discoverExistingKeys()
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
@@ -165,47 +160,6 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||
dnsKeys, err := getSystemDNSKeys()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get system DNS keys: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||
if strings.Contains(dnsKeys, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||
for i := 0; ; i++ {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
if !strings.Contains(dnsKeys, key) {
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// getSystemDNSKeys gets all DNS keys
|
||||
func getSystemDNSKeys() (string, error) {
|
||||
command := "list .*DNS\nquit\n"
|
||||
out, err := runSystemConfigCommand(command)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@@ -230,11 +184,12 @@ func (s *systemConfigurator) addLocalDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add local dns state: %w", err)
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.createdKeys[localKey] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -325,77 +280,28 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
var batches [][]string
|
||||
var current []string
|
||||
currentBytes := 0
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
for _, d := range domains {
|
||||
domainLen := len(d)
|
||||
newBytes := currentBytes + domainLen
|
||||
if currentBytes > 0 {
|
||||
newBytes++ // space separator
|
||||
}
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||
batches = append(batches, current)
|
||||
current = nil
|
||||
currentBytes = 0
|
||||
}
|
||||
|
||||
current = append(current, d)
|
||||
if currentBytes > 0 {
|
||||
currentBytes += 1 + domainLen
|
||||
} else {
|
||||
currentBytes = domainLen
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
batches = append(batches, current)
|
||||
}
|
||||
|
||||
return batches
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeKeysContaining removes all created keys that contain the given substring.
|
||||
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||
var toRemove []string
|
||||
for key := range s.createdKeys {
|
||||
if strings.Contains(key, suffix) {
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
var multiErr *multierror.Error
|
||||
for _, key := range toRemove {
|
||||
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
|
||||
for i, batch := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
domainsStr := strings.Join(batch, " ")
|
||||
|
||||
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||
}
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -458,6 +364,7 @@ func (s *systemConfigurator) flushDNSCache() error {
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -52,22 +49,17 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
// Collect all created keys for cleanup verification
|
||||
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||
for key := range configurator.createdKeys {
|
||||
createdKeys = append(createdKeys, key)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
_ = removeTestDNSKey(localKey)
|
||||
}()
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
@@ -91,223 +83,13 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||
func generateShortDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
label := ""
|
||||
n := i
|
||||
for {
|
||||
label = string(rune('a'+n%26)) + label
|
||||
n = n/26 - 1
|
||||
if n < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
domains = append(domains, label+".com")
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||
func generateLongDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||
t.Helper()
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||
out, err := cmd.Output()
|
||||
require.NoError(t, err, "scutil show should succeed")
|
||||
|
||||
var domains []string
|
||||
inArray := false
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||
inArray = true
|
||||
continue
|
||||
}
|
||||
if inArray {
|
||||
if line == "}" {
|
||||
break
|
||||
}
|
||||
// lines look like: "0 : a.com"
|
||||
parts := strings.SplitN(line, " : ", 2)
|
||||
if len(parts) == 2 {
|
||||
domains = append(domains, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NoError(t, scanner.Err())
|
||||
return domains
|
||||
}
|
||||
|
||||
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expectedCount int
|
||||
checkAllPresent bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
domains: nil,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "under_limit",
|
||||
domains: generateShortDomains(10),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "at_element_limit",
|
||||
domains: generateShortDomains(50),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "over_element_limit",
|
||||
domains: generateShortDomains(51),
|
||||
expectedCount: 2,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "triple_element_limit",
|
||||
domains: generateShortDomains(150),
|
||||
expectedCount: 3,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "long_domains_hit_byte_limit",
|
||||
domains: generateLongDomains(50),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_short_domains",
|
||||
domains: generateShortDomains(500),
|
||||
expectedCount: 10,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_long_domains",
|
||||
domains: generateLongDomains(500),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
batches := splitDomainsIntoBatches(tc.domains)
|
||||
|
||||
if tc.expectedCount > 0 {
|
||||
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||
}
|
||||
|
||||
// Verify each batch respects limits
|
||||
for i, batch := range batches {
|
||||
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||
"batch %d exceeds element limit", i)
|
||||
|
||||
totalBytes := 0
|
||||
for j, d := range batch {
|
||||
if j > 0 {
|
||||
totalBytes++
|
||||
}
|
||||
totalBytes += len(d)
|
||||
}
|
||||
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||
}
|
||||
|
||||
if tc.checkAllPresent {
|
||||
var all []string
|
||||
for _, batch := range batches {
|
||||
all = append(all, batch...)
|
||||
}
|
||||
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||
// and verifies all domains are readable across multiple scutil keys.
|
||||
func TestMatchDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
count int
|
||||
generator func(int) []string
|
||||
}{
|
||||
{"short_10", 10, generateShortDomains},
|
||||
{"short_50", 50, generateShortDomains},
|
||||
{"short_100", 100, generateShortDomains},
|
||||
{"short_200", 200, generateShortDomains},
|
||||
{"short_500", 500, generateShortDomains},
|
||||
{"long_10", 10, generateLongDomains},
|
||||
{"long_50", 50, generateLongDomains},
|
||||
{"long_100", 100, generateLongDomains},
|
||||
{"long_200", 200, generateLongDomains},
|
||||
{"long_500", 500, generateLongDomains},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
domains := tc.generator(tc.count)
|
||||
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||
|
||||
// Read back all domains from all batched keys
|
||||
var got []string
|
||||
for i := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "key %s should exist", key)
|
||||
|
||||
got = append(got, readDomainsFromKey(t, key)...)
|
||||
}
|
||||
|
||||
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||
assert.Equal(t, domains, got, "domains should match in order")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
@@ -376,15 +158,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for key := range configurator.createdKeys {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
// Also clean up old-format keys and local key in case they exist
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
|
||||
@@ -42,8 +42,6 @@ const (
|
||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
// Update count even on error to ensure cleanup covers partially created rules
|
||||
r.nrptEntryCount = count
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||
ruleIndex := 0
|
||||
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||
end := i + nrptMaxDomainsPerRule
|
||||
if end > len(domains) {
|
||||
end = len(domains)
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
batchDomains := domains[i:end]
|
||||
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||
|
||||
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||
}
|
||||
|
||||
// Increment immediately so the caller's cleanup path knows about this rule
|
||||
ruleIndex++
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||
return ruleIndex, nil
|
||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||
return len(domains), nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
domains125 := make([]DomainConfig, 125)
|
||||
for i := 0; i < 125; i++ {
|
||||
domains125[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config125 := HostDNSConfig{
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains125,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config125, nil)
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify 3 NRPT rules exist
|
||||
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||
for i := 0; i < 3; i++ {
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||
domains75 := make([]DomainConfig, 75)
|
||||
for i := 0; i < 75; i++ {
|
||||
domains75[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config75 := HostDNSConfig{
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains75,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config75, nil)
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 NRPT rules exist
|
||||
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify rule 2 is cleaned up
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
// Clean up more entries to account for batching tests with many domains
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
|
||||
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||
func TestNRPTDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domainCount int
|
||||
expectedRuleCount int
|
||||
}{
|
||||
{
|
||||
name: "Less than 50 domains (single rule)",
|
||||
domainCount: 30,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Exactly 50 domains (single rule)",
|
||||
domainCount: 50,
|
||||
expectedRuleCount: 1,
|
||||
},
|
||||
{
|
||||
name: "51 domains (two rules)",
|
||||
domainCount: 51,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "100 domains (two rules)",
|
||||
domainCount: 100,
|
||||
expectedRuleCount: 2,
|
||||
},
|
||||
{
|
||||
name: "125 domains (three rules: 50+50+25)",
|
||||
domainCount: 125,
|
||||
expectedRuleCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clean up before each subtest
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
// Generate domains
|
||||
domains := make([]DomainConfig, tc.domainCount)
|
||||
for i := 0; i < tc.domainCount; i++ {
|
||||
domains[i] = DomainConfig{
|
||||
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||
MatchOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: domains,
|
||||
}
|
||||
|
||||
err := cfg.applyDNSConfig(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that exactly expectedRuleCount rules were created
|
||||
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||
|
||||
// Verify all expected rules exist
|
||||
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||
}
|
||||
|
||||
// Verify no extra rules were created
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
|
||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
}
|
||||
}
|
||||
|
||||
// Flow receiver domain is intentionally excluded from caching.
|
||||
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||
// causes TLS certificate verification failures on reconnect.
|
||||
if serverDomains.Flow != "" {
|
||||
domains = append(domains, serverDomains.Flow)
|
||||
}
|
||||
|
||||
for _, stun := range serverDomains.Stuns {
|
||||
if stun != "" {
|
||||
|
||||
@@ -391,8 +391,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
@@ -401,10 +400,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
||||
|
||||
domainStrings := make([]string, len(finalDomains))
|
||||
for i, d := range finalDomains {
|
||||
@@ -413,5 +412,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
assert.Contains(t, domainStrings, "example.org")
|
||||
assert.Contains(t, domainStrings, "google.com")
|
||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||
assert.NotContains(t, domainStrings, "github.com")
|
||||
assert.Contains(t, domainStrings, "github.com")
|
||||
}
|
||||
|
||||
@@ -84,28 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// EndBatch mock implementation of EndBatch from Server interface
|
||||
func (m *MockServer) EndBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||
func (m *MockServer) CancelBatch() {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
@@ -104,23 +104,3 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -45,9 +45,6 @@ type IosDnsManager interface {
|
||||
type Server interface {
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
BeginBatch()
|
||||
EndBatch()
|
||||
CancelBatch()
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() netip.Addr
|
||||
@@ -57,8 +54,6 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -92,7 +87,6 @@ type DefaultServer struct {
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
batchMode bool
|
||||
|
||||
mgmtCacheResolver *mgmt.Resolver
|
||||
|
||||
@@ -106,17 +100,12 @@ type DefaultServer struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
routeMatch func(netip.Addr) bool
|
||||
|
||||
probeMu sync.Mutex
|
||||
probeCancel context.CancelFunc
|
||||
probeWg sync.WaitGroup
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
Stop()
|
||||
ProbeAvailability(context.Context)
|
||||
ProbeAvailability()
|
||||
ID() types.HandlerID
|
||||
}
|
||||
|
||||
@@ -152,7 +141,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -187,16 +176,11 @@ func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -237,14 +221,6 @@ func newDefaultServer(
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
// SetRouteChecker sets the function used by upstream resolvers to determine
|
||||
// whether an IP is routed through the tunnel.
|
||||
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.routeMatch = f
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||
// Any previously registered handler for the same domain and priority will be replaced.
|
||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
@@ -258,9 +234,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -289,41 +263,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
if !s.batchMode {
|
||||
s.applyHostConfig()
|
||||
}
|
||||
}
|
||||
|
||||
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||
func (s *DefaultServer) BeginBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode enabled")
|
||||
s.batchMode = true
|
||||
}
|
||||
|
||||
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||
func (s *DefaultServer) EndBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||
s.batchMode = false
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||
// This is useful when operations fail partway through and you want to
|
||||
// discard partial state rather than applying it.
|
||||
func (s *DefaultServer) CancelBatch() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||
s.batchMode = false
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||
|
||||
@@ -380,26 +322,9 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
}
|
||||
s.ctxCancel()
|
||||
s.probeMu.Unlock()
|
||||
s.probeWg.Wait()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
@@ -412,12 +337,8 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
@@ -520,8 +441,7 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
}
|
||||
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds.
|
||||
// If a previous probe is still running, it will be cancelled before starting a new one.
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
@@ -534,52 +454,15 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
}
|
||||
}
|
||||
|
||||
s.probeMu.Lock()
|
||||
|
||||
// don't start probes on a stopped server
|
||||
if s.ctx.Err() != nil {
|
||||
s.probeMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// cancel any running probe
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
s.probeCancel = nil
|
||||
}
|
||||
|
||||
// wait for the previous probe goroutines to finish while holding
|
||||
// the mutex so no other caller can start a new probe concurrently
|
||||
s.probeWg.Wait()
|
||||
|
||||
// start a new probe
|
||||
probeCtx, probeCancel := context.WithCancel(s.ctx)
|
||||
s.probeCancel = probeCancel
|
||||
|
||||
s.probeWg.Add(1)
|
||||
defer s.probeWg.Done()
|
||||
|
||||
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
|
||||
s.mux.Lock()
|
||||
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
handlers = append(handlers, mux.handler)
|
||||
}
|
||||
s.mux.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range handlers {
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
go func(h handlerWithStop) {
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
h.ProbeAvailability(probeCtx)
|
||||
}(handler)
|
||||
mux.ProbeAvailability()
|
||||
}(mux.handler)
|
||||
}
|
||||
|
||||
s.probeMu.Unlock()
|
||||
|
||||
wg.Wait()
|
||||
probeCancel()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
@@ -640,7 +523,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
// Always apply host config for management updates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
@@ -774,7 +656,6 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
|
||||
return
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == config.ServerIP {
|
||||
@@ -884,7 +765,6 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
@@ -1007,7 +887,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
@@ -1043,7 +922,6 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
@@ -1069,7 +947,6 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||
handler.deactivate = func(error) {}
|
||||
|
||||
@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
mockSrvB, ok := srvB.(*MockServer)
|
||||
if !ok {
|
||||
t.Errorf("returned server is not a MockServer")
|
||||
}
|
||||
|
||||
if mockSrvB != srv {
|
||||
if srvB != srv {
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1065,13 +1065,13 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ProbeAvailability() {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,25 +4,15 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop() error
|
||||
Stop()
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -6,17 +6,12 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -35,33 +30,25 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -82,86 +69,43 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
func (s *serviceViaListener) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -188,6 +132,12 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -236,28 +186,18 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
return false
|
||||
}
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -11,7 +10,6 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -20,8 +18,7 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
udpFilterHookID string
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -31,13 +28,14 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
|
||||
return &ServiceViaMemory{
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -48,8 +46,10 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,29 +57,19 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -98,18 +88,10 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -118,16 +100,12 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -135,30 +113,13 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
return true
|
||||
}
|
||||
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
@@ -1,444 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,61 +41,10 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -116,12 +65,10 @@ type upstreamResolverBase struct {
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
routeMatch func(netip.Addr) bool
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
@@ -168,11 +115,6 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -189,16 +131,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -213,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -228,7 +161,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -238,17 +171,15 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -265,7 +196,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -282,13 +213,10 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -332,10 +260,16 @@ func formatFailures(failures []upstreamFailure) string {
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
@@ -345,39 +279,31 @@ func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errs *multierror.Error
|
||||
var errors *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
upstream := upstream
|
||||
|
||||
wg.Add(1)
|
||||
go func(upstream netip.AddrPort) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
errors = multierror.Append(errors, err)
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
success = true
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errs.ErrorOrNil())
|
||||
u.disable(errors.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
@@ -413,7 +339,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
@@ -425,22 +351,16 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
err := backoff.Retry(operation, exponentialBackOff)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
@@ -463,11 +383,7 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
@@ -478,57 +394,23 @@ func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -547,32 +429,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
client.Net = "tcp"
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -580,46 +455,18 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -640,7 +487,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -65,13 +65,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
} else {
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
resolver.ProbeAvailability()
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
@@ -475,298 +475,3 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,28 +263,20 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
fields := log.Fields{
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
})
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
fields := log.Fields{
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
})
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -28,17 +28,15 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
@@ -46,21 +44,22 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -76,11 +75,13 @@ import (
|
||||
const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
// EngineConfig is a config for the Engine
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -142,18 +143,6 @@ type EngineConfig struct {
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
type EngineServices struct {
|
||||
SignalClient signal.Client
|
||||
MgmClient mgm.Client
|
||||
RelayManager *relayClient.Manager
|
||||
StatusRecorder *peer.Status
|
||||
Checks []*mgmProto.Checks
|
||||
StateManager *statemanager.Manager
|
||||
UpdateManager *updater.Manager
|
||||
ClientMetrics *metrics.ClientMetrics
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
type Engine struct {
|
||||
// signal is a Signal Service client
|
||||
@@ -211,19 +200,19 @@ type Engine struct {
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
updateManager *updatemanager.Manager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
@@ -233,13 +222,8 @@ type Engine struct {
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
|
||||
// clientMetrics collects and pushes metrics
|
||||
clientMetrics *metrics.ClientMetrics
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -256,32 +240,35 @@ type localIpUpdater interface {
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
stateManager *statemanager.Manager,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||
mgmClient: mgmClient,
|
||||
relayManager: relayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -324,7 +311,7 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.SetDownloadOnly()
|
||||
e.updateManager.Stop()
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
@@ -432,7 +419,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
@@ -502,17 +488,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
@@ -524,11 +499,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -540,13 +510,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
||||
}()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -597,6 +560,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -824,30 +794,45 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
if e.updateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if autoUpdateSettings.Version == disableAutoUpdate {
|
||||
log.Infof("auto-update is disabled")
|
||||
e.updateManager.SetDownloadOnly()
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
return
|
||||
}
|
||||
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.updateManager = updateManager
|
||||
e.updateManager.Start(e.ctx)
|
||||
}
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
log.Infof("sync finished in %s", duration)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
log.Infof("sync finished in %s", time.Since(started))
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
@@ -858,7 +843,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
@@ -1023,11 +1008,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
// Cannot update the IP address without restarting the engine because
|
||||
// the firewall, route manager, and other components cache the old address
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address)
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
@@ -1095,7 +1079,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
@@ -1333,7 +1316,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1550,13 +1534,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
PortForwardManager: e.portForwardManager,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1579,10 +1562,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
start := time.Now()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
gotLock := time.Since(start)
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -1606,8 +1587,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
@@ -1713,12 +1692,6 @@ func (e *Engine) close() {
|
||||
if e.rpManager != nil {
|
||||
_ = e.rpManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -1822,7 +1795,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1847,28 +1820,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
// GetFirewallManager returns the firewall manager.
|
||||
// GetFirewallManager returns the firewall manager
|
||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||
return e.firewall
|
||||
}
|
||||
|
||||
// GetExposeManager returns the expose session manager.
|
||||
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -251,6 +251,9 @@ func TestEngine_SSH(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
@@ -260,13 +263,10 @@ func TestEngine_SSH(t *testing.T) {
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -428,18 +428,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
@@ -652,18 +647,13 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -822,18 +812,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
@@ -1029,18 +1014,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
@@ -1538,8 +1518,13 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1561,12 +1546,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
const (
|
||||
renewTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
// Request holds the parameters for exposing a local service via the management server.
|
||||
// It is part of the embed API surface and exposed via a type alias.
|
||||
type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol ProtocolType
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
ListenPort uint16
|
||||
}
|
||||
|
||||
type ManagementClient interface {
|
||||
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||
RenewExpose(ctx context.Context, domain string) error
|
||||
StopExpose(ctx context.Context, domain string) error
|
||||
}
|
||||
|
||||
// Manager handles expose session lifecycle via the management client.
|
||||
type Manager struct {
|
||||
mgmClient ManagementClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewManager creates a new expose Manager using the given management client.
|
||||
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||
}
|
||||
|
||||
// Expose creates a new expose session via the management server.
|
||||
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||
log.Infof("exposing service on port %d", req.Port)
|
||||
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("expose session created for %s", resp.Domain)
|
||||
|
||||
return fromClientExposeResponse(resp), nil
|
||||
}
|
||||
|
||||
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
|
||||
// It is part of the embed API surface and exposed via a type alias.
|
||||
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer m.stop(domain)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := m.renew(ctx, domain); err != nil {
|
||||
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renew extends the TTL of an active expose session.
|
||||
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||
defer cancel()
|
||||
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||
}
|
||||
|
||||
// stop terminates an active expose session.
|
||||
func (m *Manager) stop(domain string) {
|
||||
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||
defer cancel()
|
||||
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||
if err != nil {
|
||||
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
func TestManager_Expose_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return &mgm.ExposeResponse{
|
||||
ServiceName: "my-service",
|
||||
ServiceURL: "https://my-service.example.com",
|
||||
Domain: "my-service.example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||
}
|
||||
|
||||
func TestManager_Expose_Error(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return nil, errors.New("permission denied")
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||
}
|
||||
|
||||
func TestManager_Renew_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
err := m.renew(context.Background(), "my-service.example.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Renew_Timeout(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(ctx, mock)
|
||||
err := m.renew(ctx, "my-service.example.com")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
req := &daemonProto.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||
Pin: "123456",
|
||||
Password: "secret",
|
||||
UserGroups: []string{"group1", "group2"},
|
||||
Domain: "custom.example.com",
|
||||
NamePrefix: "my-prefix",
|
||||
}
|
||||
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtocolType represents the protocol used for exposing a service.
|
||||
type ProtocolType int
|
||||
|
||||
const (
|
||||
// ProtocolHTTP exposes the service as HTTP.
|
||||
ProtocolHTTP ProtocolType = 0
|
||||
// ProtocolHTTPS exposes the service as HTTPS.
|
||||
ProtocolHTTPS ProtocolType = 1
|
||||
// ProtocolTCP exposes the service as TCP.
|
||||
ProtocolTCP ProtocolType = 2
|
||||
// ProtocolUDP exposes the service as UDP.
|
||||
ProtocolUDP ProtocolType = 3
|
||||
// ProtocolTLS exposes the service as TLS.
|
||||
ProtocolTLS ProtocolType = 4
|
||||
)
|
||||
|
||||
// ParseProtocolType parses a protocol string into a ProtocolType.
|
||||
func ParseProtocolType(s string) (ProtocolType, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case "http":
|
||||
return ProtocolHTTP, nil
|
||||
case "https":
|
||||
return ProtocolHTTPS, nil
|
||||
case "tcp":
|
||||
return ProtocolTCP, nil
|
||||
case "udp":
|
||||
return ProtocolUDP, nil
|
||||
case "tls":
|
||||
return ProtocolTLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
|
||||
}
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: ProtocolType(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
Domain: req.Domain,
|
||||
NamePrefix: req.NamePrefix,
|
||||
ListenPort: uint16(req.ListenPort),
|
||||
}
|
||||
}
|
||||
|
||||
func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
return mgm.ExposeRequest{
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
ListenPort: req.ListenPort,
|
||||
}
|
||||
}
|
||||
|
||||
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||
return &Response{
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
PortAutoAssigned: response.PortAutoAssigned,
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package metrics
|
||||
|
||||
// ConnectionType represents the type of peer connection
|
||||
type ConnectionType string
|
||||
|
||||
const (
|
||||
// ConnectionTypeICE represents a direct peer-to-peer connection using ICE
|
||||
ConnectionTypeICE ConnectionType = "ice"
|
||||
|
||||
// ConnectionTypeRelay represents a relayed connection
|
||||
ConnectionTypeRelay ConnectionType = "relay"
|
||||
)
|
||||
|
||||
// String returns the string representation of the connection type
|
||||
func (c ConnectionType) String() string {
|
||||
return string(c)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DeploymentType represents the type of NetBird deployment
|
||||
type DeploymentType int
|
||||
|
||||
const (
|
||||
// DeploymentTypeUnknown represents an unknown or uninitialized deployment type
|
||||
DeploymentTypeUnknown DeploymentType = iota
|
||||
|
||||
// DeploymentTypeCloud represents a cloud-hosted NetBird deployment
|
||||
DeploymentTypeCloud
|
||||
|
||||
// DeploymentTypeSelfHosted represents a self-hosted NetBird deployment
|
||||
DeploymentTypeSelfHosted
|
||||
)
|
||||
|
||||
// String returns the string representation of the deployment type
|
||||
func (d DeploymentType) String() string {
|
||||
switch d {
|
||||
case DeploymentTypeCloud:
|
||||
return "cloud"
|
||||
case DeploymentTypeSelfHosted:
|
||||
return "selfhosted"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineDeploymentType determines if the deployment is cloud or self-hosted
|
||||
// based on the management URL string
|
||||
func DetermineDeploymentType(managementURL string) DeploymentType {
|
||||
if managementURL == "" {
|
||||
return DeploymentTypeUnknown
|
||||
}
|
||||
|
||||
u, err := url.Parse(managementURL)
|
||||
if err != nil {
|
||||
return DeploymentTypeSelfHosted
|
||||
}
|
||||
|
||||
if strings.ToLower(u.Hostname()) == "api.netbird.io" {
|
||||
return DeploymentTypeCloud
|
||||
}
|
||||
|
||||
return DeploymentTypeSelfHosted
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend.
|
||||
// Metrics collection itself is always active (for debug bundles).
|
||||
// Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push.
|
||||
EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED"
|
||||
|
||||
// EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending
|
||||
EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING"
|
||||
|
||||
// EnvMetricsConfigURL is the environment variable to override the metrics push config ServerAddress
|
||||
EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL"
|
||||
|
||||
// EnvMetricsServerURL is the environment variable to override the metrics server address.
|
||||
// When set, this takes precedence over the server_url from remote push config.
|
||||
EnvMetricsServerURL = "NB_METRICS_SERVER_URL"
|
||||
|
||||
// EnvMetricsInterval overrides the push interval from the remote config.
|
||||
// Only affects how often metrics are pushed; remote config availability
|
||||
// and version range checks are still respected.
|
||||
// Format: duration string like "1h", "30m", "4h"
|
||||
EnvMetricsInterval = "NB_METRICS_INTERVAL"
|
||||
|
||||
defaultMetricsConfigURL = "https://ingest.netbird.io/config"
|
||||
)
|
||||
|
||||
// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var.
|
||||
// Disabled by default. Metrics collection is always active for debug bundles.
|
||||
func IsMetricsPushEnabled() bool {
|
||||
enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled))
|
||||
return enabled
|
||||
}
|
||||
|
||||
// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var.
|
||||
// Returns 0 if not set or invalid.
|
||||
func getMetricsInterval() time.Duration {
|
||||
intervalStr := os.Getenv(EnvMetricsInterval)
|
||||
if intervalStr == "" {
|
||||
return 0
|
||||
}
|
||||
interval, err := time.ParseDuration(intervalStr)
|
||||
if err != nil {
|
||||
log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err)
|
||||
return 0
|
||||
}
|
||||
if interval <= 0 {
|
||||
log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr)
|
||||
return 0
|
||||
}
|
||||
return interval
|
||||
}
|
||||
|
||||
func isForceSending() bool {
|
||||
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
|
||||
return force
|
||||
}
|
||||
|
||||
// getMetricsConfigURL returns the URL to fetch push configuration from
|
||||
func getMetricsConfigURL() string {
|
||||
if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" {
|
||||
return envURL
|
||||
}
|
||||
return defaultMetricsConfigURL
|
||||
}
|
||||
|
||||
// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var.
|
||||
// Returns nil if not set or invalid.
|
||||
func getMetricsServerURL() *url.URL {
|
||||
envURL := os.Getenv(EnvMetricsServerURL)
|
||||
if envURL == "" {
|
||||
return nil
|
||||
}
|
||||
parsed, err := url.ParseRequestURI(envURL)
|
||||
if err != nil || parsed.Host == "" {
|
||||
log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL)
|
||||
return nil
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme)
|
||||
return nil
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxSampleAge = 5 * 24 * time.Hour // drop samples older than 5 days
|
||||
maxBufferSize = 5 * 1024 * 1024 // drop oldest samples when estimated size exceeds 5 MB
|
||||
// estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + fields + timestamp)
|
||||
estimatedSampleSize = 256
|
||||
)
|
||||
|
||||
// influxSample is a single InfluxDB line protocol entry.
|
||||
type influxSample struct {
|
||||
measurement string
|
||||
tags string
|
||||
fields map[string]float64
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// influxDBMetrics collects metric events as timestamped samples.
|
||||
// Each event is recorded with its exact timestamp, pushed once, then cleared.
|
||||
type influxDBMetrics struct {
|
||||
mu sync.Mutex
|
||||
samples []influxSample
|
||||
}
|
||||
|
||||
func newInfluxDBMetrics() metricsImplementation {
|
||||
return &influxDBMetrics{}
|
||||
}
|
||||
func (m *influxDBMetrics) RecordConnectionStages(
|
||||
_ context.Context,
|
||||
agentInfo AgentInfo,
|
||||
connectionPairID string,
|
||||
connectionType ConnectionType,
|
||||
isReconnection bool,
|
||||
timestamps ConnectionStageTimestamps,
|
||||
) {
|
||||
var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64
|
||||
|
||||
if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() {
|
||||
signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds()
|
||||
}
|
||||
|
||||
if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
|
||||
connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds()
|
||||
}
|
||||
|
||||
if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
|
||||
totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds()
|
||||
}
|
||||
|
||||
attemptType := "initial"
|
||||
if isReconnection {
|
||||
attemptType = "reconnection"
|
||||
}
|
||||
|
||||
connTypeStr := connectionType.String()
|
||||
tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
connTypeStr,
|
||||
attemptType,
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
connectionPairID,
|
||||
)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_peer_connection",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"signaling_to_connection_seconds": signalingReceivedToConnection,
|
||||
"connection_to_wg_handshake_seconds": connectionToWgHandshake,
|
||||
"total_seconds": totalDuration,
|
||||
},
|
||||
timestamp: now,
|
||||
})
|
||||
m.trimLocked()
|
||||
|
||||
log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs",
|
||||
agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration)
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_sync",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"duration_seconds": duration.Seconds(),
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
result = "failure"
|
||||
}
|
||||
|
||||
tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
result,
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_login",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"duration_seconds": duration.Seconds(),
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
|
||||
log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds())
|
||||
}
|
||||
|
||||
// Export writes pending samples in InfluxDB line protocol format.
|
||||
// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns
|
||||
func (m *influxDBMetrics) Export(w io.Writer) error {
|
||||
m.mu.Lock()
|
||||
samples := make([]influxSample, len(m.samples))
|
||||
copy(samples, m.samples)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, s := range samples {
|
||||
if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sortedKeys := slices.Sorted(maps.Keys(s.fields))
|
||||
first := true
|
||||
for _, k := range sortedKeys {
|
||||
if !first {
|
||||
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil {
|
||||
return err
|
||||
}
|
||||
first = false
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset clears pending samples after a successful push
|
||||
func (m *influxDBMetrics) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.samples = m.samples[:0]
|
||||
}
|
||||
|
||||
// trimLocked removes samples that exceed age or size limits.
|
||||
// Must be called with m.mu held.
|
||||
func (m *influxDBMetrics) trimLocked() {
|
||||
now := time.Now()
|
||||
|
||||
// drop samples older than maxSampleAge
|
||||
cutoff := 0
|
||||
for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge {
|
||||
cutoff++
|
||||
}
|
||||
if cutoff > 0 {
|
||||
copy(m.samples, m.samples[cutoff:])
|
||||
m.samples = m.samples[:len(m.samples)-cutoff]
|
||||
log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge)
|
||||
}
|
||||
|
||||
// drop oldest samples if estimated size exceeds maxBufferSize
|
||||
maxSamples := maxBufferSize / estimatedSampleSize
|
||||
if len(m.samples) > maxSamples {
|
||||
drop := len(m.samples) - maxSamples
|
||||
copy(m.samples, m.samples[drop:])
|
||||
m.samples = m.samples[:maxSamples]
|
||||
log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024))
|
||||
}
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInfluxDBMetrics_RecordAndExport(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeCloud,
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
peerID: "abc123",
|
||||
}
|
||||
|
||||
ts := ConnectionStageTimestamps{
|
||||
SignalingReceived: time.Now().Add(-3 * time.Second),
|
||||
ConnectionReady: time.Now().Add(-2 * time.Second),
|
||||
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
|
||||
}
|
||||
|
||||
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "netbird_peer_connection,")
|
||||
assert.Contains(t, output, "connection_to_wg_handshake_seconds=")
|
||||
assert.Contains(t, output, "signaling_to_connection_seconds=")
|
||||
assert.Contains(t, output, "total_seconds=")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeCloud,
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
peerID: "abc123",
|
||||
}
|
||||
|
||||
ts := ConnectionStageTimestamps{
|
||||
SignalingReceived: time.Now().Add(-3 * time.Second),
|
||||
ConnectionReady: time.Now().Add(-2 * time.Second),
|
||||
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
|
||||
}
|
||||
|
||||
// Record multiple times and verify consistent field order
|
||||
for i := 0; i < 10; i++ {
|
||||
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
|
||||
require.Len(t, lines, 10)
|
||||
|
||||
// Extract field portion from each line and verify they're all identical
|
||||
var fieldSections []string
|
||||
for _, line := range lines {
|
||||
parts := strings.SplitN(line, " ", 3)
|
||||
require.Len(t, parts, 3, "each line should have measurement, fields, timestamp")
|
||||
fieldSections = append(fieldSections, parts[1])
|
||||
}
|
||||
|
||||
for i := 1; i < len(fieldSections); i++ {
|
||||
assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples")
|
||||
}
|
||||
|
||||
// Fields should be alphabetically sorted
|
||||
assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="),
|
||||
"fields should be sorted: connection_to_wg < signaling_to < total")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeSelfHosted,
|
||||
Version: "2.0.0",
|
||||
OS: "darwin",
|
||||
Arch: "arm64",
|
||||
peerID: "def456",
|
||||
}
|
||||
|
||||
m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "netbird_sync,")
|
||||
assert.Contains(t, output, "duration_seconds=1.5")
|
||||
assert.Contains(t, output, "deployment_type=selfhosted")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_Reset(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeCloud,
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
peerID: "abc123",
|
||||
}
|
||||
|
||||
m.RecordSyncDuration(context.Background(), agentInfo, time.Second)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, buf.String())
|
||||
|
||||
m.Reset()
|
||||
|
||||
buf.Reset()
|
||||
err = m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, buf.String(), "should be empty after reset")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_ExportEmpty(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_TrimByAge(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
m.mu.Lock()
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "old",
|
||||
tags: "t=1",
|
||||
fields: map[string]float64{"v": 1},
|
||||
timestamp: time.Now().Add(-maxSampleAge - time.Hour),
|
||||
})
|
||||
m.trimLocked()
|
||||
remaining := len(m.samples)
|
||||
m.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 0, remaining, "old samples should be trimmed")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeCloud,
|
||||
Version: "1.0.0",
|
||||
OS: "linux",
|
||||
Arch: "amd64",
|
||||
peerID: "abc123",
|
||||
}
|
||||
|
||||
m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "netbird_login,")
|
||||
assert.Contains(t, output, "duration_seconds=2.5")
|
||||
assert.Contains(t, output, "result=success")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
agentInfo := AgentInfo{
|
||||
DeploymentType: DeploymentTypeSelfHosted,
|
||||
Version: "1.0.0",
|
||||
OS: "darwin",
|
||||
Arch: "arm64",
|
||||
peerID: "xyz789",
|
||||
}
|
||||
|
||||
m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := m.Export(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "netbird_login,")
|
||||
assert.Contains(t, output, "result=failure")
|
||||
assert.Contains(t, output, "deployment_type=selfhosted")
|
||||
}
|
||||
|
||||
func TestInfluxDBMetrics_TrimBySize(t *testing.T) {
|
||||
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||
|
||||
maxSamples := maxBufferSize / estimatedSampleSize
|
||||
m.mu.Lock()
|
||||
for i := 0; i < maxSamples+100; i++ {
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "test",
|
||||
tags: "t=1",
|
||||
fields: map[string]float64{"v": float64(i)},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
m.trimLocked()
|
||||
remaining := len(m.samples)
|
||||
m.mu.Unlock()
|
||||
|
||||
assert.Equal(t, maxSamples, remaining, "should trim to max samples")
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
# Copy to .env and adjust values before running docker compose
|
||||
|
||||
# InfluxDB admin (server-side only, never exposed to clients)
|
||||
INFLUXDB_ADMIN_PASSWORD=changeme
|
||||
INFLUXDB_ADMIN_TOKEN=changeme
|
||||
|
||||
# Grafana admin credentials
|
||||
GRAFANA_ADMIN_USER=admin
|
||||
GRAFANA_ADMIN_PASSWORD=changeme
|
||||
|
||||
# Remote config served by ingest at /config
|
||||
# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable
|
||||
CONFIG_METRICS_SERVER_URL=
|
||||
CONFIG_VERSION_SINCE=0.0.0
|
||||
CONFIG_VERSION_UNTIL=99.99.99
|
||||
CONFIG_PERIOD_MINUTES=5
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user