mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:44:34 -04:00
Compare commits
196 Commits
fix/handle
...
netmap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
840b07c784 | ||
|
|
85e991ff78 | ||
|
|
f9845e53a0 | ||
|
|
765aba2c1c | ||
|
|
7cb81f1d70 | ||
|
|
cea19de667 | ||
|
|
29e5eceb6b | ||
|
|
0f63737330 | ||
|
|
bf518c5fba | ||
|
|
eab6183a8e | ||
|
|
4517da8b3a | ||
|
|
9c0d923124 | ||
|
|
6857734c48 | ||
|
|
3b019800f8 | ||
|
|
4cd4f88666 | ||
|
|
d2157bda66 | ||
|
|
43a8ba97e3 | ||
|
|
17874771cc | ||
|
|
f6ccf6b97a | ||
|
|
6aae797baf | ||
|
|
aca054e51e | ||
|
|
10cee8f46e | ||
|
|
628673db20 | ||
|
|
eaa31c2dc6 | ||
|
|
25723e9b07 | ||
|
|
3cf4d5758f | ||
|
|
fc15ee6351 | ||
|
|
4a3e78fb0f | ||
|
|
f9462eea27 | ||
|
|
b075009ef7 | ||
|
|
c347a4c2ca | ||
|
|
61bc092458 | ||
|
|
b679404618 | ||
|
|
215fb257f7 | ||
|
|
381447b8d6 | ||
|
|
919c1cb3d4 | ||
|
|
85d17cbc89 | ||
|
|
c9f3854dde | ||
|
|
245b086646 | ||
|
|
1609b21b5b | ||
|
|
1f926d15b8 | ||
|
|
a432e8e23a | ||
|
|
4fec709bb1 | ||
|
|
95299be52d | ||
|
|
f51cae7103 | ||
|
|
f68d5e965f | ||
|
|
85b8f36ec1 | ||
|
|
94e505480b | ||
|
|
10d8617be6 | ||
|
|
deffe037aa | ||
|
|
983d7bafbe | ||
|
|
4da29451d0 | ||
|
|
9b3449753e | ||
|
|
456629811b | ||
|
|
c311d0d19e | ||
|
|
521f7dd39f | ||
|
|
f9ec0a9a2e | ||
|
|
012235ff12 | ||
|
|
f176807ebe | ||
|
|
d4c47eaf8a | ||
|
|
d35a79d3b5 | ||
|
|
6a2929011d | ||
|
|
e877c9d6c1 | ||
|
|
7a1c96ebf4 | ||
|
|
41fe9f84ec | ||
|
|
d13fb0e379 | ||
|
|
f3214527ea | ||
|
|
69048bfd34 | ||
|
|
29a2d93873 | ||
|
|
6b01b0020e | ||
|
|
9d3db68805 | ||
|
|
2e315311e0 | ||
|
|
67e2185964 | ||
|
|
89149dc6f4 | ||
|
|
5a1f8f13a2 | ||
|
|
e71059d245 | ||
|
|
91fa2e20a0 | ||
|
|
61034aaf4d | ||
|
|
b8717b8956 | ||
|
|
50201d63c2 | ||
|
|
d11b39282b | ||
|
|
bd58eea8ea | ||
|
|
a5811a2d7d | ||
|
|
a680f80ed9 | ||
|
|
10fbdc2c4a | ||
|
|
1444fbe104 | ||
|
|
650bca7ca8 | ||
|
|
570e28d227 | ||
|
|
272ade07a8 | ||
|
|
263abe4862 | ||
|
|
ceee421a05 | ||
|
|
0a75da6fb7 | ||
|
|
920877964f | ||
|
|
2e0047daea | ||
|
|
ce0718fcb5 | ||
|
|
c590518e0c | ||
|
|
f309b120cd | ||
|
|
7357a9954c | ||
|
|
13b63eebc1 | ||
|
|
735ed7ab34 | ||
|
|
961d9198ef | ||
|
|
df4ca01848 | ||
|
|
4e7c17756c | ||
|
|
6a4935139d | ||
|
|
35dd991776 | ||
|
|
3598418206 | ||
|
|
e435e39158 | ||
|
|
fd26e989e3 | ||
|
|
4424162bce | ||
|
|
54b045d9ca | ||
|
|
71c6437bab | ||
|
|
7b254cb966 | ||
|
|
8f3a0f2c38 | ||
|
|
1f33e2e003 | ||
|
|
1e6addaa65 | ||
|
|
f51dc13f8c | ||
|
|
3477108ce7 | ||
|
|
012e624296 | ||
|
|
4c5e987e02 | ||
|
|
a80c8b0176 | ||
|
|
9e01155d2e | ||
|
|
3c3111ad01 | ||
|
|
b74078fd95 | ||
|
|
77488ad11a | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 | ||
|
|
76702c8a09 | ||
|
|
061f673a4f | ||
|
|
9505805313 | ||
|
|
704c67dec8 | ||
|
|
3ed2f08f3c | ||
|
|
4c83408f27 | ||
|
|
90bd39c740 | ||
|
|
dd0cf41147 | ||
|
|
22b2caffc6 | ||
|
|
c1f66d1354 | ||
|
|
ac0fe6025b | ||
|
|
c28657710a | ||
|
|
3875c29f6b | ||
|
|
9f32ccd453 | ||
|
|
1d1d057e7d | ||
|
|
3461b1bb90 | ||
|
|
3d2a2377c6 | ||
|
|
25f5f26527 | ||
|
|
bb0d5c5baf | ||
|
|
7938295190 | ||
|
|
9af532fe71 | ||
|
|
23a1473797 | ||
|
|
9c2dc05df1 | ||
|
|
40d56e5d29 | ||
|
|
fd23d0c28f | ||
|
|
4fff93a1f2 | ||
|
|
22beac1b1b | ||
|
|
bd7a65d798 | ||
|
|
2d76b058fc | ||
|
|
ea2d060f93 | ||
|
|
68b377a28c | ||
|
|
af50eb350f | ||
|
|
2475473227 | ||
|
|
846871913d | ||
|
|
6cba9c0818 | ||
|
|
f0672b87bc | ||
|
|
9b0fe2c8e5 | ||
|
|
abd57d1191 | ||
|
|
416f04c27a | ||
|
|
fc7c1e397f | ||
|
|
52a3ac6b06 | ||
|
|
0b3b50c705 | ||
|
|
042141db06 | ||
|
|
4a1aee1ae0 | ||
|
|
ba33572ec9 | ||
|
|
9d213e0b54 | ||
|
|
5dde044fa5 | ||
|
|
5a3d9e401f | ||
|
|
fde1a2196c | ||
|
|
0aeb87742a | ||
|
|
6d747b2f83 | ||
|
|
199bf73103 | ||
|
|
17f5abc653 | ||
|
|
aa935bdae3 | ||
|
|
452419c4c3 | ||
|
|
17b1099032 | ||
|
|
a4b9e93217 | ||
|
|
63d7957140 | ||
|
|
9a6814deff | ||
|
|
190698bcf2 | ||
|
|
468fa2940b | ||
|
|
79a0647a26 | ||
|
|
17ceb3bde8 | ||
|
|
5a8f1763a6 | ||
|
|
f64e73ca70 |
2
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
2
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -2,7 +2,7 @@
|
||||
name: Bug/Issue report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ['triage']
|
||||
labels: ['triage-needed']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
5
.github/workflows/golang-test-darwin.yml
vendored
5
.github/workflows/golang-test-darwin.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['jsonfile', 'sqlite']
|
||||
store: ['sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -32,6 +32,9 @@ jobs:
|
||||
restore-keys: |
|
||||
macos-go-
|
||||
|
||||
- name: Install libpcap
|
||||
run: brew install libpcap
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
|
||||
39
.github/workflows/golang-test-freebsd.yml
vendored
Normal file
39
.github/workflows/golang-test-freebsd.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
|
||||
name: Test Code FreeBSD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
prepare: |
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
|
||||
run: |
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
||||
22
.github/workflows/golang-test-linux.yml
vendored
22
.github/workflows/golang-test-linux.yml
vendored
@@ -14,8 +14,8 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['386','amd64']
|
||||
store: ['jsonfile', 'sqlite']
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -36,7 +36,11 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -67,7 +71,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -82,7 +86,10 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
@@ -104,12 +111,15 @@ jobs:
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
3
.github/workflows/golang-test-windows.yml
vendored
3
.github/workflows/golang-test-windows.yml
vendored
@@ -44,10 +44,9 @@ jobs:
|
||||
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
- run: "[Environment]::SetEnvironmentVariable('NETBIRD_STORE_ENGINE', 'jsonfile', 'Machine')"
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
8
.github/workflows/golangci-lint.yml
vendored
8
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta
|
||||
ignore_words_list: erro,clienta,hastable,
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
@@ -33,6 +33,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
@@ -40,7 +44,7 @@ jobs:
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
|
||||
12
.github/workflows/mobile-build-validation.yml
vendored
12
.github/workflows/mobile-build-validation.yml
vendored
@@ -11,7 +11,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
andrloid_build:
|
||||
android_build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android nebtird lib
|
||||
- name: build android netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
@@ -56,10 +56,10 @@ jobs:
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS nebtird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
35
.github/workflows/release.yml
vendored
35
.github/workflows/release.yml
vendored
@@ -7,17 +7,7 @@ on:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.goreleaser.yml'
|
||||
- '.goreleaser_ui.yaml'
|
||||
- '.goreleaser_ui_darwin.yaml'
|
||||
- '.github/workflows/release.yml'
|
||||
- 'release_files/**'
|
||||
- '**/Dockerfile'
|
||||
- '**/Dockerfile.*'
|
||||
- 'client/ui/**'
|
||||
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
@@ -106,6 +96,27 @@ jobs:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload linux packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload windows packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload macos packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
retention-days: 3
|
||||
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -162,7 +173,7 @@ jobs:
|
||||
retention-days: 3
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-11
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
66
.github/workflows/test-infrastructure-files.yml
vendored
66
.github/workflows/test-infrastructure-files.yml
vendored
@@ -162,6 +162,13 @@ jobs:
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files/artifacts
|
||||
|
||||
- name: test geolocation databases
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -171,34 +178,79 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen
|
||||
- name: test Caddy file gen postgres
|
||||
run: test -f Caddyfile
|
||||
- name: test docker-compose file gen
|
||||
|
||||
- name: test docker-compose file gen postgres
|
||||
run: test -f docker-compose.yml
|
||||
- name: test management.json file gen
|
||||
|
||||
- name: test management.json file gen postgres
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
|
||||
- name: test zitadel.env file gen postgres
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
|
||||
- name: test dashboard.env file gen postgres
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
env:
|
||||
NETBIRD_DOMAIN: use-ip
|
||||
ZITADEL_DATABASE: cockroach
|
||||
|
||||
- name: test Caddy file gen CockroachDB
|
||||
run: test -f Caddyfile
|
||||
|
||||
- name: test docker-compose file gen CockroachDB
|
||||
run: test -f docker-compose.yml
|
||||
|
||||
- name: test management.json file gen CockroachDB
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
|
||||
- name: test zitadel.env file gen CockroachDB
|
||||
run: test -f zitadel.env
|
||||
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
@@ -63,6 +63,14 @@ linters-settings:
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
@@ -93,6 +101,7 @@ linters:
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
@@ -121,3 +130,10 @@ issues:
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
# Exclude specific deprecation warnings for grpc methods
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.DialContext is deprecated"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.WithBlock is deprecated"
|
||||
|
||||
@@ -3,8 +3,10 @@ builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env: [CGO_ENABLED=1]
|
||||
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
- MACOS_DEPLOYMENT_TARGET=11.0
|
||||
goos:
|
||||
- darwin
|
||||
goarch:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
identity and expression, level of experience, education, socioeconomic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
|
||||
45
README.md
45
README.md
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
@@ -40,27 +40,26 @@
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
### Open-Source Network Security in a Single Platform
|
||||
|
||||
|
||||

|
||||
|
||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Automation | Platforms |
|
||||
|---------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| <ul><li> - \[x] Post-quantum-secure connection through [Rosenpass](https://rosenpass.eu) </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
||||
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
@@ -79,7 +78,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
@@ -96,9 +95,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
@@ -109,8 +108,8 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
@@ -122,7 +121,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
FROM alpine:3.18.5
|
||||
FROM alpine:3.19
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||
COPY netbird /go/bin/netbird
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -54,14 +57,17 @@ type Client struct {
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
uiVersion: uiVersion,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
@@ -84,6 +90,9 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||
//nolint
|
||||
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
|
||||
|
||||
c.ctxCancelLock.Lock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
defer c.ctxCancel()
|
||||
@@ -97,7 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -122,7 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
212
client/anonymize/anonymize.go
Normal file
212
client/anonymize/anonymize.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package anonymize
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Anonymizer struct {
|
||||
ipAnonymizer map[netip.Addr]netip.Addr
|
||||
domainAnonymizer map[string]string
|
||||
currentAnonIPv4 netip.Addr
|
||||
currentAnonIPv6 netip.Addr
|
||||
startAnonIPv4 netip.Addr
|
||||
startAnonIPv6 netip.Addr
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 192.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
}
|
||||
|
||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
return &Anonymizer{
|
||||
ipAnonymizer: map[netip.Addr]netip.Addr{},
|
||||
domainAnonymizer: map[string]string{},
|
||||
currentAnonIPv4: startIPv4,
|
||||
currentAnonIPv6: startIPv6,
|
||||
startAnonIPv4: startIPv4,
|
||||
startAnonIPv6: startIPv6,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
if ip.IsLoopback() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsInterfaceLocalMulticast() ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsUnspecified() ||
|
||||
ip.IsMulticast() ||
|
||||
isWellKnown(ip) ||
|
||||
a.isInAnonymizedRange(ip) {
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
if _, ok := a.ipAnonymizer[ip]; !ok {
|
||||
if ip.Is4() {
|
||||
a.ipAnonymizer[ip] = a.currentAnonIPv4
|
||||
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
|
||||
} else {
|
||||
a.ipAnonymizer[ip] = a.currentAnonIPv6
|
||||
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
|
||||
}
|
||||
}
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
return true
|
||||
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return ip
|
||||
}
|
||||
|
||||
return a.AnonymizeIP(addr).String()
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
||||
if strings.HasSuffix(domain, "netbird.io") ||
|
||||
strings.HasSuffix(domain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(domain, "netbird.cloud") ||
|
||||
strings.HasSuffix(domain, "netbird.stage") ||
|
||||
strings.HasSuffix(domain, ".domain") {
|
||||
return domain
|
||||
}
|
||||
|
||||
parts := strings.Split(domain, ".")
|
||||
if len(parts) < 2 {
|
||||
return domain
|
||||
}
|
||||
|
||||
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
|
||||
anonymized, ok := a.domainAnonymizer[baseDomain]
|
||||
if !ok {
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
|
||||
a.domainAnonymizer[baseDomain] = anonymizedBase
|
||||
anonymized = anonymizedBase
|
||||
}
|
||||
|
||||
return strings.Replace(domain, baseDomain, anonymized, 1)
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return uri
|
||||
}
|
||||
|
||||
var anonymizedHost string
|
||||
if u.Opaque != "" {
|
||||
host, port, err := net.SplitHostPort(u.Opaque)
|
||||
if err == nil {
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
||||
}
|
||||
u.Opaque = anonymizedHost
|
||||
} else if u.Host != "" {
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Host)
|
||||
}
|
||||
u.Host = anonymizedHost
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeString(str string) string {
|
||||
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
|
||||
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
|
||||
|
||||
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
|
||||
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
|
||||
|
||||
for domain, anonDomain := range a.domainAnonymizer {
|
||||
str = strings.ReplaceAll(str, domain, anonDomain)
|
||||
}
|
||||
|
||||
str = a.AnonymizeSchemeURI(str)
|
||||
str = a.AnonymizeDNSLogLine(str)
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
||||
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
|
||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
||||
domainRegex := regexp.MustCompile(domainPattern)
|
||||
|
||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.Split(match, `"`)
|
||||
if len(parts) >= 2 {
|
||||
domain := parts[1]
|
||||
if strings.HasSuffix(domain, ".domain") {
|
||||
return match
|
||||
}
|
||||
randomDomain := generateRandomString(10) + ".domain"
|
||||
return strings.Replace(match, domain, randomDomain, 1)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
|
||||
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
|
||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||
}
|
||||
|
||||
if slices.Contains(wellKnown, addr.String()) {
|
||||
return true
|
||||
}
|
||||
|
||||
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
|
||||
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
|
||||
|
||||
return cgnatRange.Contains(addr)
|
||||
}
|
||||
|
||||
func generateRandomString(length int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result[i] = letters[num.Int64()]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
223
client/anonymize/anonymize_test.go
Normal file
223
client/anonymize/anonymize_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package anonymize_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
)
|
||||
|
||||
func TestAnonymizeIP(t *testing.T) {
|
||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||
startIPv6 := netip.MustParseAddr("100::")
|
||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expect string
|
||||
}{
|
||||
{"Well known", "8.8.8.8", "8.8.8.8"},
|
||||
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Second Public IPv6", "a::b", "100::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := netip.MustParseAddr(tc.ip)
|
||||
anonymizedIP := anonymizer.AnonymizeIP(ip)
|
||||
if anonymizedIP.String() != tc.expect {
|
||||
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
||||
|
||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
||||
require.NotEqual(t, testLog, result)
|
||||
assert.NotContains(t, result, "example.com")
|
||||
}
|
||||
|
||||
func TestAnonymizeDomain(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
expectPattern string
|
||||
shouldAnonymize bool
|
||||
}{
|
||||
{
|
||||
"General Domain",
|
||||
"example.com",
|
||||
`^anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain",
|
||||
"sub.example.com",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Protected Domain",
|
||||
"netbird.io",
|
||||
`^netbird\.io$`,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeDomain(tc.domain)
|
||||
if tc.shouldAnonymize {
|
||||
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
|
||||
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
|
||||
} else {
|
||||
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
regex string
|
||||
}{
|
||||
{
|
||||
"HTTP URI with Port",
|
||||
"http://example.com:80/path",
|
||||
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
|
||||
},
|
||||
{
|
||||
"HTTP URI without Port",
|
||||
"http://example.com/path",
|
||||
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
|
||||
},
|
||||
{
|
||||
"Opaque URI with Port",
|
||||
"stun:example.com:80?transport=udp",
|
||||
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
|
||||
},
|
||||
{
|
||||
"Opaque URI without Port",
|
||||
"stun:example.com?transport=udp",
|
||||
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeURI(tc.uri)
|
||||
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
|
||||
require.NotContains(t, result, "example.com", "Original domain should not be present")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeSchemeURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeSchemeURI(tc.input)
|
||||
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
|
||||
require.NotContains(t, result, "example.com", "Original domain should not be present")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizString_MemorizedDomain(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
domain := "example.com"
|
||||
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
|
||||
|
||||
sampleString := "This is a test string including the domain example.com which should be anonymized."
|
||||
|
||||
firstPassResult := anonymizer.AnonymizeString(sampleString)
|
||||
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
|
||||
|
||||
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
|
||||
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
|
||||
|
||||
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
|
||||
}
|
||||
|
||||
func TestAnonymizeString_DoubleURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
domain := "example.com"
|
||||
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
|
||||
|
||||
sampleString := "Check out our site at https://example.com for more info."
|
||||
|
||||
firstPassResult := anonymizer.AnonymizeString(sampleString)
|
||||
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
|
||||
|
||||
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
|
||||
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
|
||||
|
||||
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
|
||||
}
|
||||
|
||||
func TestAnonymizeString_IPAddresses(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 Address",
|
||||
input: "Error occurred at IP 122.138.1.1",
|
||||
expect: "Error occurred at IP 198.51.100.0",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address",
|
||||
input: "Access attempted from 2001:db8::ff00:42",
|
||||
expect: "Access attempted from 100::",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address with Port",
|
||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||
expect: "Access attempted from [100::]:8080",
|
||||
},
|
||||
{
|
||||
name: "Both IPv4 and IPv6",
|
||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeString(tc.input)
|
||||
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
|
||||
})
|
||||
}
|
||||
}
|
||||
255
client/cmd/debug.go
Normal file
255
client/cmd/debug.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
|
||||
}
|
||||
|
||||
var debugBundleCmd = &cobra.Command{
|
||||
Use: "bundle",
|
||||
Example: " netbird debug bundle",
|
||||
Short: "Create a debug bundle",
|
||||
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
|
||||
RunE: debugBundle,
|
||||
}
|
||||
|
||||
var logCmd = &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage logging for the Netbird daemon",
|
||||
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
|
||||
}
|
||||
|
||||
var logLevelCmd = &cobra.Command{
|
||||
Use: "level <level>",
|
||||
Short: "Set the logging level for this session",
|
||||
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
|
||||
Available log levels are:
|
||||
panic: for panic level, highest level of severity
|
||||
fatal: for fatal level errors that cause the program to exit
|
||||
error: for error conditions
|
||||
warn: for warning conditions
|
||||
info: for informational messages
|
||||
debug: for debug-level messages
|
||||
trace: for trace-level messages, which include more fine-grained information than debug`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: setLogLevel,
|
||||
}
|
||||
|
||||
var forCmd = &cobra.Command{
|
||||
Use: "for <time>",
|
||||
Short: "Run debug logs for a specified duration and create a debug bundle",
|
||||
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
|
||||
Example: " netbird debug for 5m",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runForDuration,
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: level,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Log level set successfully to", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
duration, err := time.ParseDuration(args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration format: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: proto.LogLevel_TRACE,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
elapsed := time.Since(startTime)
|
||||
if elapsed >= duration {
|
||||
return
|
||||
}
|
||||
remaining := duration - elapsed
|
||||
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
d = d.Round(time.Second)
|
||||
h := d / time.Hour
|
||||
d %= time.Hour
|
||||
m := d / time.Minute
|
||||
d %= time.Minute
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
@@ -2,9 +2,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
|
||||
@@ -32,8 +32,11 @@ const (
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,8 +64,14 @@ var (
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
rootCmd = &cobra.Command{
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
Long: "",
|
||||
@@ -100,15 +109,24 @@ func init() {
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
@@ -116,8 +134,20 @@ func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(routesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
|
||||
routesCmd.AddCommand(routesListCmd)
|
||||
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
debugCmd.AddCommand(debugBundleCmd)
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
@@ -325,3 +355,17 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
174
client/cmd/route.go
Normal file
174
client/cmd/route.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var appendFlag bool
|
||||
|
||||
var routesCmd = &cobra.Command{
|
||||
Use: "routes",
|
||||
Short: "Manage network routes",
|
||||
Long: `Commands to list, select, or deselect network routes.`,
|
||||
}
|
||||
|
||||
var routesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List routes",
|
||||
Example: " netbird routes list",
|
||||
Long: "List all available network routes.",
|
||||
RunE: routesList,
|
||||
}
|
||||
|
||||
var routesSelectCmd = &cobra.Command{
|
||||
Use: "select route...|all",
|
||||
Short: "Select routes",
|
||||
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
|
||||
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesSelect,
|
||||
}
|
||||
|
||||
var routesDeselectCmd = &cobra.Command{
|
||||
Use: "deselect route...|all",
|
||||
Short: "Deselect routes",
|
||||
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
|
||||
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesDeselect,
|
||||
}
|
||||
|
||||
func init() {
|
||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
|
||||
}
|
||||
|
||||
func routesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.Routes) == 0 {
|
||||
cmd.Println("No routes available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printRoutes(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Route) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
} else if appendFlag {
|
||||
req.Append = true
|
||||
}
|
||||
|
||||
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes selected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
}
|
||||
|
||||
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes deselected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -24,12 +22,8 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
name := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "Netbird"
|
||||
}
|
||||
return &service.Config{
|
||||
Name: name,
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Option: make(service.KeyValue),
|
||||
|
||||
@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
|
||||
@@ -24,7 +24,7 @@ var (
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh",
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"You can verify the connection by running:\n\n" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -34,7 +37,9 @@ type peerStateDetailOutput struct {
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
@@ -72,19 +77,28 @@ type iceCandidateType struct {
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -133,9 +147,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx, cmd)
|
||||
resp, err := getStatus(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -168,7 +182,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false)
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -180,7 +194,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
@@ -189,7 +203,7 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -268,6 +282,13 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
if anonymizeFlag {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
anonymizeOverview(anonymizer, &overview)
|
||||
}
|
||||
|
||||
return overview
|
||||
@@ -299,6 +320,19 @@ func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
@@ -351,7 +385,9 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetRoutes(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
@@ -401,8 +437,7 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool) string {
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
@@ -438,7 +473,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relayAvailableString string
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
@@ -447,15 +482,46 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relayAvailableString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
|
||||
relayAvailableString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
routes := "-"
|
||||
if len(overview.Routes) > 0 {
|
||||
sort.Strings(overview.Routes)
|
||||
routes = strings.Join(overview.Routes, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
@@ -465,26 +531,41 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
if goarch == "arm" {
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"Daemon version: %s\n"+
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relayAvailableString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
@@ -492,7 +573,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
@@ -529,15 +610,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
lastStatusUpdate := "-"
|
||||
if !peerState.LastStatusUpdate.IsZero() {
|
||||
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
lastWireGuardHandshake := "-"
|
||||
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
|
||||
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
@@ -556,6 +628,12 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
}
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(peerState.Routes) > 0 {
|
||||
sort.Strings(peerState.Routes)
|
||||
routes = strings.Join(peerState.Routes, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
@@ -569,7 +647,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n",
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n"+
|
||||
" Latency: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
@@ -580,11 +660,13 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
lastStatusUpdate,
|
||||
lastWireGuardHandshake,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
peerState.Latency.String(),
|
||||
)
|
||||
|
||||
peersString += peerString
|
||||
@@ -638,3 +720,144 @@ func toIEC(b int64) string {
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
||||
func timeAgo(t time.Time) string {
|
||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
||||
return "-"
|
||||
}
|
||||
duration := time.Since(t)
|
||||
switch {
|
||||
case duration < time.Second:
|
||||
return "Now"
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Seconds())
|
||||
if seconds == 1 {
|
||||
return "1 second ago"
|
||||
}
|
||||
return fmt.Sprintf("%d seconds ago", seconds)
|
||||
case duration < time.Hour:
|
||||
minutes := int(duration.Minutes())
|
||||
seconds := int(duration.Seconds()) % 60
|
||||
if minutes == 1 {
|
||||
if seconds == 1 {
|
||||
return "1 minute, 1 second ago"
|
||||
} else if seconds > 0 {
|
||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
||||
}
|
||||
return "1 minute ago"
|
||||
}
|
||||
if seconds > 0 {
|
||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d minutes ago", minutes)
|
||||
case duration < 24*time.Hour:
|
||||
hours := int(duration.Hours())
|
||||
minutes := int(duration.Minutes()) % 60
|
||||
if hours == 1 {
|
||||
if minutes == 1 {
|
||||
return "1 hour, 1 minute ago"
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
||||
}
|
||||
return "1 hour ago"
|
||||
}
|
||||
if minutes > 0 {
|
||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d hours ago", hours)
|
||||
}
|
||||
|
||||
days := int(duration.Hours()) / 24
|
||||
hours := int(duration.Hours()) % 24
|
||||
if days == 1 {
|
||||
if hours == 1 {
|
||||
return "1 day, 1 hour ago"
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
||||
}
|
||||
return "1 day ago"
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%d days ago", days)
|
||||
}
|
||||
|
||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
||||
}
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
anonymizePeerDetail(a, &peer)
|
||||
overview.Peers.Details[i] = peer
|
||||
}
|
||||
|
||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
||||
|
||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
||||
for i, detail := range overview.Relays.Details {
|
||||
detail.URI = a.AnonymizeURI(detail.URI)
|
||||
detail.Error = a.AnonymizeString(detail.Error)
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
for j, ns := range nsGroup.Servers {
|
||||
host, port, err := net.SplitHostPort(ns)
|
||||
if err == nil {
|
||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
@@ -3,11 +3,14 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -42,6 +45,10 @@ var resp = &proto.StatusResponse{
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
@@ -58,6 +65,7 @@ var resp = &proto.StatusResponse{
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
@@ -87,6 +95,31 @@ var resp = &proto.StatusResponse{
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
@@ -116,6 +149,10 @@ var overview = statusOutputOverview{
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
@@ -136,6 +173,7 @@ var overview = statusOutputOverview{
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -171,6 +209,31 @@ var overview = statusOutputOverview{
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -232,7 +295,11 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"quantumResistance":false
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
@@ -253,7 +320,9 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"quantumResistance":false
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": null
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -289,8 +358,33 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance":false,
|
||||
"quantumResistancePermissive":false
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
@@ -324,7 +418,10 @@ func TestParsingToYAML(t *testing.T) {
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
@@ -341,7 +438,9 @@ func TestParsingToYAML(t *testing.T) {
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
@@ -368,15 +467,37 @@ usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
// Calculate time ago based on the fixture dates
|
||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail :=
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
@@ -387,10 +508,12 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Last connection update: 2001-01-01 01:01:01
|
||||
Last WireGuard handshake: 2001-01-01 01:01:02
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
Latency: 10ms
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
@@ -401,41 +524,50 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Last connection update: 2002-02-02 02:02:02
|
||||
Last WireGuard handshake: 2002-02-02 02:02:03
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
CLI version: %s
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false, false)
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString :=
|
||||
`Daemon version: 0.14.1
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -449,3 +581,31 @@ func TestParsingOfIP(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
|
||||
func TestTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
expected string
|
||||
}{
|
||||
{"Now", now, "Now"},
|
||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
||||
{"Zero time", time.Time{}, "-"},
|
||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := timeAgo(tc.input)
|
||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -51,7 +56,10 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
panic(err)
|
||||
@@ -68,22 +76,24 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -98,7 +108,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||
t *testing.T, ctx context.Context, _, configPath string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -40,6 +42,12 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -83,11 +91,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
@@ -114,6 +123,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -130,6 +143,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@@ -145,11 +162,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
return connectClient.Run()
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -190,6 +208,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -224,6 +243,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.WireguardPort = &wp
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
loginRequest.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
30
client/errors/errors.go
Normal file
30
client/errors/errors.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
||||
func FormatErrorOrNil(err *multierror.Error) error {
|
||||
if err != nil {
|
||||
err.ErrorFormat = formatError
|
||||
}
|
||||
return err.ErrorOrNil()
|
||||
}
|
||||
@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Debug("creating an iptables firewall manager")
|
||||
log.Info("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
case NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager")
|
||||
log.Info("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
@@ -85,16 +85,58 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
useIPTABLES := false
|
||||
var iptablesChains []string
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err == nil && isIptablesClientAvailable(ip) {
|
||||
major, minor, _ := ip.GetIptablesVersion()
|
||||
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
|
||||
if major < 1 || (major == 1 && minor < 8) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
useIPTABLES = true
|
||||
|
||||
iptablesChains, err = ip.ListChains("filter")
|
||||
if err != nil {
|
||||
log.Errorf("failed to list iptables chains: %s", err)
|
||||
useIPTABLES = false
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
}
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
|
||||
if useIPTABLES {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
|
||||
@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -87,12 +87,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts an iptable rule
|
||||
// insertRoutingRule inserts an iptables rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
@@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
@@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
@@ -326,9 +334,33 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification with comment identifier
|
||||
func genRuleSpec(jump, id, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
|
||||
@@ -51,14 +51,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
Destination: "100.100.100.0/24",
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
@@ -92,7 +90,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
@@ -103,7 +101,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
@@ -114,7 +112,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
@@ -130,7 +128,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
@@ -167,25 +165,25 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
return m.router.AddRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
@@ -126,6 +128,22 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
@@ -138,28 +156,28 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
|
||||
@@ -64,15 +64,18 @@ func manageFirewallRule(ruleName string, action action, extraArgs ...string) err
|
||||
if action == addRule {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
netshCmd := GetSystem32Command("netsh")
|
||||
cmd := exec.Command(netshCmd, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func isWindowsFirewallReachable() bool {
|
||||
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
||||
cmd := exec.Command("netsh", args...)
|
||||
|
||||
netshCmd := GetSystem32Command("netsh")
|
||||
|
||||
cmd := exec.Command(netshCmd, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
|
||||
_, err := cmd.Output()
|
||||
@@ -87,8 +90,23 @@ func isWindowsFirewallReachable() bool {
|
||||
func isFirewallRuleActive(ruleName string) bool {
|
||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
netshCmd := GetSystem32Command("netsh")
|
||||
|
||||
cmd := exec.Command(netshCmd, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
_, err := cmd.Output()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
|
||||
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
|
||||
func GetSystem32Command(command string) string {
|
||||
_, err := exec.LookPath(command)
|
||||
if err == nil {
|
||||
return command
|
||||
}
|
||||
|
||||
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
|
||||
|
||||
return "C:\\windows\\system32\\" + command + ".exe"
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type HTTPClient interface {
|
||||
}
|
||||
|
||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||
type AuthFlowInfo struct {
|
||||
type AuthFlowInfo struct { //nolint:revive
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
|
||||
@@ -5,12 +5,17 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@@ -30,8 +35,10 @@ const (
|
||||
DefaultAdminURL = "https://app.netbird.io:443"
|
||||
)
|
||||
|
||||
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||
var defaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
@@ -46,7 +53,10 @@ type ConfigInput struct {
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -58,6 +68,7 @@ type Config struct {
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
@@ -88,6 +99,9 @@ type Config struct {
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -97,6 +111,14 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// initialize through apply() without changes
|
||||
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, err
|
||||
} else if changed {
|
||||
if err = WriteOutConfig(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -149,78 +171,15 @@ func WriteOutConfig(path string, config *Config) error {
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
wgKey := generateKey()
|
||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
SSHKey: string(pem),
|
||||
PrivateKey: wgKey,
|
||||
IFaceBlackList: []string{},
|
||||
DisableIPv6Discovery: false,
|
||||
NATExternalIPs: input.NATExternalIPs,
|
||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||
ServerSSHAllowed: util.False(),
|
||||
DisableAutoConnect: false,
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
}
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
if _, err := config.apply(input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.ManagementURL = defaultManagementURL
|
||||
if input.ManagementURL != "" {
|
||||
URL, err := parseURL("Management URL", input.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ManagementURL = URL
|
||||
}
|
||||
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
}
|
||||
|
||||
config.WgIface = iface.WgInterfaceDefault
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil {
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
}
|
||||
|
||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.AdminURL = defaultAdminURL
|
||||
if input.AdminURL != "" {
|
||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.AdminURL = newURL
|
||||
}
|
||||
|
||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -231,97 +190,12 @@ func update(input ConfigInput) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refresh := false
|
||||
|
||||
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
|
||||
log.Infof("new Management URL provided, updated to %s (old value %s)",
|
||||
input.ManagementURL, config.ManagementURL)
|
||||
newURL, err := parseURL("Management URL", input.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ManagementURL = newURL
|
||||
refresh = true
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
|
||||
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
|
||||
input.AdminURL, config.AdminURL)
|
||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.AdminURL = newURL
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, replacing old key")
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.SSHKey = string(pem)
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
||||
config.NATExternalIPs = input.NATExternalIPs
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.CustomDNSAddress != nil {
|
||||
config.CustomDNSAddress = string(input.CustomDNSAddress)
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.DisableAutoConnect != nil {
|
||||
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.ServerSSHAllowed == nil {
|
||||
config.ServerSSHAllowed = util.True()
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if refresh {
|
||||
// since we have new management URL, we need to update config file
|
||||
if updated {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -330,6 +204,190 @@ func update(input ConfigInput) (*Config, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
if config.ManagementURL == nil {
|
||||
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
|
||||
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
|
||||
input.ManagementURL, config.ManagementURL.String())
|
||||
URL, err := parseURL("Management URL", input.ManagementURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
config.ManagementURL = URL
|
||||
updated = true
|
||||
} else if config.ManagementURL == nil {
|
||||
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if config.AdminURL == nil {
|
||||
log.Infof("using default Admin URL %s", DefaultManagementURL)
|
||||
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
|
||||
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
|
||||
input.AdminURL, config.AdminURL.String())
|
||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||
if err != nil {
|
||||
return updated, err
|
||||
}
|
||||
config.AdminURL = newURL
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.PrivateKey == "" {
|
||||
log.Infof("generated new Wireguard key")
|
||||
config.PrivateKey = generateKey()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
log.Infof("generated new SSH key")
|
||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
config.SSHKey = string(pem)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
|
||||
log.Infof("updating Wireguard port %d (old value %d)",
|
||||
*input.WireguardPort, config.WgPort)
|
||||
config.WgPort = *input.WireguardPort
|
||||
updated = true
|
||||
} else if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||
log.Infof("updating Wireguard interface %#v (old value %#v)",
|
||||
*input.InterfaceName, config.WgIface)
|
||||
config.WgIface = *input.InterfaceName
|
||||
updated = true
|
||||
} else if config.WgIface == "" {
|
||||
config.WgIface = iface.WgInterfaceDefault
|
||||
log.Infof("using default Wireguard interface %s", config.WgIface)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
|
||||
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
|
||||
strings.Join(input.NATExternalIPs, " "),
|
||||
strings.Join(config.NATExternalIPs, " "))
|
||||
config.NATExternalIPs = input.NATExternalIPs
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, replacing old key")
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
|
||||
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
|
||||
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||
config.NetworkMonitor = input.NetworkMonitor
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.NetworkMonitor == nil {
|
||||
// enable network monitoring by default on windows and darwin clients
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||
enabled := true
|
||||
config.NetworkMonitor = &enabled
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
|
||||
log.Infof("updating custom DNS address %#v (old value %#v)",
|
||||
string(input.CustomDNSAddress), config.CustomDNSAddress)
|
||||
config.CustomDNSAddress = string(input.CustomDNSAddress)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if len(config.IFaceBlackList) == 0 {
|
||||
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||
strings.Join(defaultInterfaceBlacklist, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if len(input.ExtraIFaceBlackList) > 0 {
|
||||
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
|
||||
log.Infof("adding new entry to interface blacklist: %s", iFace)
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
|
||||
if *input.DisableAutoConnect {
|
||||
log.Infof("turning off automatic connection on startup")
|
||||
} else {
|
||||
log.Infof("enabling automatic connection on startup")
|
||||
}
|
||||
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
||||
if *input.ServerSSHAllowed {
|
||||
log.Infof("enabling SSH server")
|
||||
} else {
|
||||
log.Infof("disabling SSH server")
|
||||
}
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
updated = true
|
||||
} else if config.ServerSSHAllowed == nil {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
config.DNSRouteInterval = *input.DNSRouteInterval
|
||||
updated = true
|
||||
} else if config.DNSRouteInterval == 0 {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
@@ -384,7 +442,6 @@ func configFileIsExists(path string) bool {
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -18,7 +18,6 @@ func TestGetConfig(t *testing.T) {
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -86,6 +85,26 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
||||
}
|
||||
|
||||
func TestExtraIFaceBlackList(t *testing.T) {
|
||||
extraIFaceBlackList := []string{"eth1"}
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: path,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, config.IFaceBlackList, "eth1")
|
||||
readConf, err := util.ReadJson(path, config)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
|
||||
}
|
||||
|
||||
func TestHiddenPreSharedKey(t *testing.T) {
|
||||
hidden := "**********"
|
||||
samplePreSharedKey := "mysecretpresharedkey"
|
||||
@@ -111,7 +130,6 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
||||
ConfigPath: cfgFile,
|
||||
PreSharedKey: tt.preSharedKey,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get cfg: %s", err)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -27,29 +31,45 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *Config
|
||||
statusRecorder *peer.Status
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
}
|
||||
|
||||
// RunClientWithProbes runs the client's main logic with probes attached
|
||||
func RunClientWithProbes(
|
||||
func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run() error {
|
||||
return c.run(MobileDependency{}, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
tunAdapter iface.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
@@ -64,40 +84,43 @@ func RunClientMobile(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func RunClientiOS(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
debug.SetGCPercent(5)
|
||||
|
||||
mobileDependency := MobileDependency{
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
func (c *ConnectClient) run(
|
||||
mobileDependency MobileDependency,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
|
||||
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
|
||||
log.Errorf("checking unclean shutdown error: %s", err)
|
||||
}
|
||||
|
||||
@@ -111,7 +134,7 @@ func runClient(
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
state := CtxGetState(ctx)
|
||||
state := CtxGetState(c.ctx)
|
||||
defer func() {
|
||||
s, err := state.Status()
|
||||
if err != nil || s != StatusNeedsLogin {
|
||||
@@ -120,49 +143,49 @@ func runClient(
|
||||
}()
|
||||
|
||||
wrapErr := state.Wrap
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
if c.config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer statusRecorder.ClientStop()
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
state.Set(StatusConnecting)
|
||||
|
||||
engineCtx, cancel := context.WithCancel(ctx)
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
statusRecorder.MarkManagementDisconnected(state.err)
|
||||
statusRecorder.CleanLocalPeerState()
|
||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||
c.statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||
mgmClient.SetConnStateListener(mgmNotifier)
|
||||
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
@@ -180,7 +203,7 @@ func runClient(
|
||||
}
|
||||
return wrapErr(err)
|
||||
}
|
||||
statusRecorder.MarkManagementConnected()
|
||||
c.statusRecorder.MarkManagementConnected()
|
||||
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
@@ -189,18 +212,18 @@ func runClient(
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
|
||||
statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
signalURL := fmt.Sprintf("%s://%s",
|
||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||
)
|
||||
|
||||
statusRecorder.UpdateSignalAddress(signalURL)
|
||||
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||
|
||||
statusRecorder.MarkSignalDisconnected(nil)
|
||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
statusRecorder.MarkSignalDisconnected(state.err)
|
||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
@@ -216,35 +239,40 @@ func runClient(
|
||||
}
|
||||
}()
|
||||
|
||||
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
|
||||
signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder)
|
||||
signalClient.SetConnStateListener(signalNotifier)
|
||||
|
||||
statusRecorder.MarkSignalConnected()
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
err = engine.Start()
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
err = c.engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
err = engine.Stop()
|
||||
err = c.engine.Stop()
|
||||
if err != nil {
|
||||
log.Errorf("failed stopping engine %v", err)
|
||||
return wrapErr(err)
|
||||
@@ -259,7 +287,7 @@ func runClient(
|
||||
return nil
|
||||
}
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
@@ -271,8 +299,20 @@ func runClient(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
var e *Engine
|
||||
c.engineMutex.Lock()
|
||||
e = c.engine
|
||||
c.engineMutex.Unlock()
|
||||
return e
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
@@ -280,12 +320,14 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -296,6 +338,15 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
engineConf.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
port, err := freePort(config.WgPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port != config.WgPort {
|
||||
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
|
||||
}
|
||||
engineConf.WgPort = port
|
||||
|
||||
return engineConf, nil
|
||||
}
|
||||
|
||||
@@ -345,3 +396,20 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
notifier, _ := sri.(signal.ConnStateNotifier)
|
||||
return notifier
|
||||
}
|
||||
|
||||
func freePort(start int) (int, error) {
|
||||
addr := net.UDPAddr{}
|
||||
if start == 0 {
|
||||
start = iface.DefaultWgPort
|
||||
}
|
||||
for x := start; x <= 65535; x++ {
|
||||
addr.Port = x
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Close()
|
||||
return x, nil
|
||||
}
|
||||
return 0, errors.New("no free ports")
|
||||
}
|
||||
|
||||
57
client/internal/connect_test.go
Normal file
57
client/internal/connect_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_freePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available",
|
||||
port: 51820,
|
||||
want: 51820,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "notavailable",
|
||||
port: 51830,
|
||||
want: 51831,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "noports",
|
||||
port: 65535,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := freePort(tt.port)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("freePort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}
|
||||
}
|
||||
6
client/internal/dns/consts_freebsd.go
Normal file
6
client/internal/dns/consts_freebsd.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
||||
)
|
||||
8
client/internal/dns/consts_linux.go
Normal file
8
client/internal/dns/consts_linux.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -14,6 +15,9 @@ const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
|
||||
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
|
||||
|
||||
type resolvConf struct {
|
||||
nameServers []string
|
||||
searchDomains []string
|
||||
@@ -103,3 +107,62 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
}
|
||||
return rconf, nil
|
||||
}
|
||||
|
||||
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
|
||||
// otherwise it adds a new option with timeout and attempts.
|
||||
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
|
||||
configs := make([]string, len(input))
|
||||
copy(configs, input)
|
||||
|
||||
for i, config := range configs {
|
||||
if strings.HasPrefix(config, "options") {
|
||||
config = strings.ReplaceAll(config, "rotate", "")
|
||||
config = strings.Join(strings.Fields(config), " ")
|
||||
|
||||
if strings.Contains(config, "timeout:") {
|
||||
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
|
||||
}
|
||||
|
||||
if strings.Contains(config, "attempts:") {
|
||||
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
|
||||
}
|
||||
|
||||
configs[i] = config
|
||||
return configs
|
||||
}
|
||||
}
|
||||
|
||||
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
|
||||
}
|
||||
|
||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
||||
// and writes the file back to the original location
|
||||
func removeFirstNbNameserver(filename, nameserverIP string) error {
|
||||
resolvConf, err := parseResolvConfFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
||||
}
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", filename, err)
|
||||
}
|
||||
|
||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
|
||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
||||
|
||||
stat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", filename, err)
|
||||
}
|
||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
||||
return fmt.Errorf("write %s: %w", filename, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseResolvConf(t *testing.T) {
|
||||
@@ -172,3 +174,131 @@ nameserver 192.168.0.1
|
||||
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOptionsWithTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
others []string
|
||||
timeout int
|
||||
attempts int
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Append new options with timeout and attempts",
|
||||
others: []string{"some config"},
|
||||
timeout: 2,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options timeout:2 attempts:2"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options to exclude rotate and include timeout and attempts",
|
||||
others: []string{"some config", "options rotate someother"},
|
||||
timeout: 3,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
|
||||
},
|
||||
{
|
||||
name: "Existing options with timeout and attempts are updated",
|
||||
others: []string{"some config", "options timeout:4 attempts:3"},
|
||||
timeout: 5,
|
||||
attempts: 4,
|
||||
expected: []string{"some config", "options timeout:5 attempts:4"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options, add missing attempts before timeout",
|
||||
others: []string{"some config", "options timeout:4"},
|
||||
timeout: 4,
|
||||
attempts: 3,
|
||||
expected: []string{"some config", "options attempts:3 timeout:4"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content string
|
||||
ipToRemove string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Unrelated nameservers with comments and options",
|
||||
content: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
},
|
||||
{
|
||||
name: "First nameserver matches",
|
||||
content: `search example.com
|
||||
nameserver 9.9.9.9
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `search example.com
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
},
|
||||
{
|
||||
name: "Target IP not the first nameserver",
|
||||
// nolint:dupword
|
||||
content: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
// nolint:dupword
|
||||
expected: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
},
|
||||
{
|
||||
name: "Only nameserver matches",
|
||||
content: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tempFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -23,10 +24,16 @@ const (
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
)
|
||||
|
||||
const (
|
||||
dnsFailoverTimeout = 4 * time.Second
|
||||
dnsFailoverAttempts = 1
|
||||
)
|
||||
|
||||
type fileConfigurator struct {
|
||||
repair *repair
|
||||
|
||||
originalPerms os.FileMode
|
||||
originalPerms os.FileMode
|
||||
nbNameserverIP string
|
||||
}
|
||||
|
||||
func newFileConfigurator() (hostManager, error) {
|
||||
@@ -40,31 +47,27 @@ func (f *fileConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
backupFileExist := false
|
||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||
if err == nil {
|
||||
backupFileExist = true
|
||||
}
|
||||
|
||||
backupFileExist := f.isBackupFileExist()
|
||||
if !config.RouteAll {
|
||||
if backupFileExist {
|
||||
err = f.restore()
|
||||
f.repair.stopWatchFileChanges()
|
||||
err := f.restore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
|
||||
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
}
|
||||
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
err := f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbSearchDomains := searchDomains(config)
|
||||
nbNameserverIP := config.ServerIP
|
||||
f.nbNameserverIP = config.ServerIP
|
||||
|
||||
resolvConf, err := parseBackupResolvConf()
|
||||
if err != nil {
|
||||
@@ -73,11 +76,11 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
|
||||
f.repair.stopWatchFileChanges()
|
||||
|
||||
err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf)
|
||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP)
|
||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -85,10 +88,11 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
|
||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||
|
||||
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
nameServers,
|
||||
cfg.others)
|
||||
options)
|
||||
|
||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
||||
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||
@@ -131,7 +135,12 @@ func (f *fileConfigurator) backup() error {
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restore() error {
|
||||
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
@@ -157,7 +166,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
||||
// not a valid first nameserver -> restore
|
||||
if err != nil {
|
||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err)
|
||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
@@ -171,6 +180,11 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) isBackupFileExist() bool {
|
||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func restoreResolvConfFile() error {
|
||||
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -65,7 +65,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("discovered mode is: %s", osManager)
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
@@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
var value string
|
||||
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
|
||||
if err == nil {
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return systemdManager, nil
|
||||
}
|
||||
}
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
63
client/internal/dns/hosts_dns_holder.go
Normal file
63
client/internal/dns/hosts_dns_holder.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type hostsDNSHolder struct {
|
||||
unprotectedDNSList map[string]struct{}
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func newHostsDNSHolder() *hostsDNSHolder {
|
||||
return &hostsDNSHolder{
|
||||
unprotectedDNSList: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) set(list []string) {
|
||||
h.mutex.Lock()
|
||||
h.unprotectedDNSList = make(map[string]struct{})
|
||||
for _, dns := range list {
|
||||
dnsAddr, err := h.normalizeAddress(dns)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
h.unprotectedDNSList[dnsAddr] = struct{}{}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) get() map[string]struct{} {
|
||||
h.mutex.RLock()
|
||||
l := h.unprotectedDNSList
|
||||
h.mutex.RUnlock()
|
||||
return l
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (h *hostsDNSHolder) isContain(upstream string) bool {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
_, ok := h.unprotectedDNSList[upstream]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
|
||||
a, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if a.Is4() {
|
||||
return fmt.Sprintf("%s:53", addr), nil
|
||||
} else {
|
||||
return fmt.Sprintf("[%s]:53", addr), nil
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
response := d.lookupRecord(r)
|
||||
if response != nil {
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -53,10 +53,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
||||
searchDomainList := searchDomains(config)
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
||||
|
||||
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||
r.othersConfigs)
|
||||
options)
|
||||
|
||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -52,13 +55,14 @@ type DefaultServer struct {
|
||||
currentConfig HostDNSConfig
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
hostsDnsList []string
|
||||
hostsDnsListLock sync.Mutex
|
||||
permanent bool
|
||||
hostsDNSHolder *hostsDNSHolder
|
||||
|
||||
// make sense on mobile only
|
||||
searchDomainNotifier *notifier
|
||||
iosDnsManager IosDnsManager
|
||||
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
@@ -73,7 +77,12 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -90,15 +99,22 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
|
||||
func NewDefaultServerPermanentUpstream(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
hostsDnsList []string,
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
||||
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
||||
@@ -108,13 +124,18 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
|
||||
}
|
||||
|
||||
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
||||
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
@@ -124,7 +145,9 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
}
|
||||
|
||||
return defaultServer
|
||||
@@ -180,10 +203,8 @@ func (s *DefaultServer) Stop() {
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDnsListLock.Lock()
|
||||
defer s.hostsDnsListLock.Unlock()
|
||||
s.hostsDNSHolder.set(hostsDnsList)
|
||||
|
||||
s.hostsDnsList = hostsDnsList
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
@@ -256,9 +277,15 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
mux.probeAvailability()
|
||||
wg.Add(1)
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.probeAvailability()
|
||||
}(mux)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
@@ -299,6 +326,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
}
|
||||
|
||||
s.updateNSGroupStates(update.NameServerGroups)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -338,7 +367,14 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
continue
|
||||
}
|
||||
|
||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
}
|
||||
@@ -416,9 +452,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
s.hostsDnsListLock.Lock()
|
||||
s.addHostRootZone()
|
||||
s.hostsDnsListLock.Unlock()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
@@ -460,14 +494,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
deactivate = func(err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
@@ -476,6 +510,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.service.DeregisterMux(nbdns.RootZone)
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
@@ -485,9 +520,17 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, err, false)
|
||||
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
@@ -506,36 +549,79 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return
|
||||
}
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
a, err := netip.ParseAddr(ua)
|
||||
if err != nil {
|
||||
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
|
||||
continue
|
||||
}
|
||||
|
||||
ipString := ua
|
||||
if !a.Is4() {
|
||||
ipString = fmt.Sprintf("[%s]", ua)
|
||||
}
|
||||
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
||||
handler.upstreamServers = make([]string, 0)
|
||||
for k := range s.hostsDNSHolder.get() {
|
||||
handler.upstreamServers = append(handler.upstreamServers, k)
|
||||
}
|
||||
handler.deactivate = func() {}
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
var states []peer.NSGroupState
|
||||
|
||||
for _, group := range groups {
|
||||
var servers []string
|
||||
for _, ns := range group.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
|
||||
state := peer.NSGroupState{
|
||||
ID: generateGroupKey(group),
|
||||
Servers: servers,
|
||||
Domains: group.Domains,
|
||||
// The probe will determine the state, default enabled
|
||||
Enabled: true,
|
||||
Error: nil,
|
||||
}
|
||||
states = append(states, state)
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
|
||||
states := s.statusRecorder.GetDNSStates()
|
||||
id := generateGroupKey(nsGroup)
|
||||
for i, state := range states {
|
||||
if state.ID == id {
|
||||
states[i].Enabled = enabled
|
||||
states[i].Error = err
|
||||
break
|
||||
}
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
var servers []string
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -38,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
@@ -260,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -274,7 +279,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -338,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@@ -375,7 +380,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -470,7 +475,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@@ -541,6 +546,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: &peer.Status{},
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
@@ -563,7 +569,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
},
|
||||
}, nil)
|
||||
|
||||
deactivate()
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
@@ -601,7 +607,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -625,7 +631,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -717,7 +723,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -748,6 +754,11 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("9.9.9.9"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Primary: false,
|
||||
@@ -790,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
20
client/internal/dns/systemd_freebsd.go
Normal file
20
client/internal/dns/systemd_freebsd.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error {
|
||||
|
||||
return v.Store(store)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return false
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@@ -14,11 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
|
||||
func CheckUncleanShutdown(wgIface string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -5,14 +5,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -45,12 +47,13 @@ type upstreamResolverBase struct {
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
deactivate func()
|
||||
reactivate func()
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(parentCTX)
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
@@ -58,6 +61,7 @@ func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,9 +72,17 @@ func (u *upstreamResolverBase) stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
defer u.checkUpstreamFails()
|
||||
var err error
|
||||
defer func() {
|
||||
u.checkUpstreamFails(err)
|
||||
}()
|
||||
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
@@ -81,7 +93,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
for _, upstream := range u.upstreamServers {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
@@ -132,7 +143,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// If fails count is greater that failsTillDeact, upstream resolving
|
||||
// will be disabled for reactivatePeriod, after that time period fails counter
|
||||
// will be reset and upstream will be reactivated.
|
||||
func (u *upstreamResolverBase) checkUpstreamFails() {
|
||||
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
@@ -146,7 +157,7 @@ func (u *upstreamResolverBase) checkUpstreamFails() {
|
||||
default:
|
||||
}
|
||||
|
||||
u.disable()
|
||||
u.disable(err)
|
||||
}
|
||||
|
||||
// probeAvailability tests all upstream servers simultaneously and
|
||||
@@ -165,13 +176,16 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errors *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
upstream := upstream
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := u.testNameserver(upstream); err != nil {
|
||||
err := u.testNameserver(upstream)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, err)
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
@@ -186,7 +200,7 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable()
|
||||
u.disable(errors.ErrorOrNil())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,18 +259,15 @@ func isTimeout(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable() {
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
// todo test the deactivation logic, it seems to affect the client
|
||||
if runtime.GOOS != "ios" {
|
||||
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.deactivate()
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server string) error {
|
||||
|
||||
84
client/internal/dns/upstream_android.go
Normal file
84
client/internal/dns/upstream_android.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
*upstreamResolverBase
|
||||
hostsDNSHolder *hostsDNSHolder
|
||||
}
|
||||
|
||||
// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN
|
||||
// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the
|
||||
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
c := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
hostsDNSHolder: hostsDNSHolder,
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = c
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected.
|
||||
// In other case the DNS resolvation goes through the VPN, so we need to force to use the
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
if u.isLocalResolver(upstream) {
|
||||
return u.exchangeWithoutVPN(ctx, upstream, r)
|
||||
} else {
|
||||
return u.exchangeWithinVPN(ctx, upstream, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
upstreamExchangeClient := &dns.Client{}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
timeout := upstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
dialTimeout := timeout
|
||||
|
||||
nbDialer := nbnet.NewDialer()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return nbDialer.Control(network, address, c)
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
}
|
||||
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Dialer: dialer,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.Exchange(r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
if u.hostsDNSHolder.isContain(upstream) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
38
client/internal/dns/upstream_general.go
Normal file
38
client/internal/dns/upstream_general.go
Normal file
@@ -0,0 +1,38 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
*upstreamResolverBase
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = nonIOS
|
||||
return nonIOS, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
upstreamExchangeClient := &dns.Client{}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
@@ -20,8 +22,15 @@ type upstreamResolverIOS struct {
|
||||
iIndex int
|
||||
}
|
||||
|
||||
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip net.IP,
|
||||
net *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type upstreamResolverNonIOS struct {
|
||||
*upstreamResolverBase
|
||||
}
|
||||
|
||||
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
|
||||
nonIOS := &upstreamResolverNonIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = nonIOS
|
||||
return nonIOS, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
upstreamExchangeClient := &dns.Client{}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
@@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func() {
|
||||
resolver.deactivate = func(error) {
|
||||
failed = true
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,17 @@
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
|
||||
@@ -2,12 +2,15 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -21,21 +24,26 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -60,6 +68,9 @@ type EngineConfig struct {
|
||||
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
||||
WgPrivateKey wgtypes.Key
|
||||
|
||||
// NetworkMonitor is a flag to enable network monitoring
|
||||
NetworkMonitor bool
|
||||
|
||||
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
@@ -83,6 +94,8 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -93,6 +106,10 @@ type Engine struct {
|
||||
mgmClient mgm.Client
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
|
||||
@@ -107,9 +124,15 @@ type Engine struct {
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
TURNs []*stun.URI
|
||||
|
||||
cancel context.CancelFunc
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
clientRoutesMu sync.RWMutex
|
||||
|
||||
ctx context.Context
|
||||
clientCtx context.Context
|
||||
clientCancel context.CancelFunc
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface *iface.WGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
@@ -119,6 +142,8 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
@@ -134,6 +159,11 @@ type Engine struct {
|
||||
signalProbe *Probe
|
||||
relayProbe *Probe
|
||||
wgProbe *Probe
|
||||
|
||||
wgConnWorker sync.WaitGroup
|
||||
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -144,17 +174,18 @@ type Peer struct {
|
||||
|
||||
// NewEngine creates a new Connection Engine
|
||||
func NewEngine(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return NewEngineWithProbes(
|
||||
ctx,
|
||||
cancel,
|
||||
clientCtx,
|
||||
clientCancel,
|
||||
signalClient,
|
||||
mgmClient,
|
||||
config,
|
||||
@@ -164,13 +195,14 @@ func NewEngine(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
checks,
|
||||
)
|
||||
}
|
||||
|
||||
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
||||
func NewEngineWithProbes(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
config *EngineConfig,
|
||||
@@ -180,10 +212,12 @@ func NewEngineWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
|
||||
return &Engine{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
mgmClient: mgmClient,
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
@@ -195,11 +229,11 @@ func NewEngineWithProbes(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
||||
mgmProbe: mgmProbe,
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
wgProbe: wgProbe,
|
||||
checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,16 +241,31 @@ func (e *Engine) Stop() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// stopping network monitor first to avoid starting the engine again
|
||||
if e.networkMonitor != nil {
|
||||
e.networkMonitor.Stop()
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = nil
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.CLose() asynchronously
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
e.wgConnWorker.Wait()
|
||||
log.Infof("stopped Netbird Engine")
|
||||
return nil
|
||||
}
|
||||
@@ -228,13 +277,21 @@ func (e *Engine) Start() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
return err
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||
return fmt.Errorf("new wg interface: %w", err)
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@@ -244,29 +301,37 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||
}
|
||||
err := e.rpManager.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
e.beforePeerHook = beforePeerHook
|
||||
e.afterPeerHook = afterPeerHook
|
||||
}
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
err = e.wgInterfaceCreate()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||
@@ -278,7 +343,7 @@ func (e *Engine) Start() error {
|
||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +351,7 @@ func (e *Engine) Start() error {
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
@@ -296,13 +361,16 @@ func (e *Engine) Start() error {
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveProbeEvents()
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -474,6 +542,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
@@ -481,7 +553,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
// if checks are equal, we skip the update
|
||||
if isChecksEqual(e.checks, checks) {
|
||||
return nil
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -497,8 +589,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on Windows is not supported")
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
@@ -571,12 +663,19 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
err := e.mgmClient.Sync(e.handleSync)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.cancel()
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
log.Debugf("stopped receiving updates from Management Service")
|
||||
@@ -638,6 +737,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
@@ -679,14 +792,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update routes, err: %v", err)
|
||||
}
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
@@ -698,30 +803,40 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||
var prefix netip.Prefix
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
var err error
|
||||
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
||||
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
convertedRoute := &route.Route{
|
||||
ID: protoRoute.ID,
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
NetID: protoRoute.NetID,
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
Peer: protoRoute.Peer,
|
||||
Metric: int(protoRoute.Metric),
|
||||
Masquerade: protoRoute.Masquerade,
|
||||
KeepRoute: protoRoute.KeepRoute,
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
@@ -781,6 +896,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusDisconnected,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
}
|
||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||
@@ -804,27 +920,39 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
if _, ok := e.peerConns[peerKey]; !ok {
|
||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
e.peerConns[peerKey] = conn
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
e.wgConnWorker.Add(1)
|
||||
go e.connWorker(conn, peerKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
defer e.wgConnWorker.Done()
|
||||
for {
|
||||
|
||||
// randomize starting time a bit
|
||||
min := 500
|
||||
max := 2000
|
||||
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
|
||||
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
return
|
||||
case <-time.After(duration):
|
||||
}
|
||||
|
||||
// if peer has been removed -> give up
|
||||
if !e.peerExists(peerKey) {
|
||||
@@ -842,11 +970,12 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
err := conn.Open()
|
||||
err := conn.Open(e.ctx)
|
||||
if err != nil {
|
||||
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
||||
switch err.(type) {
|
||||
case *peer.ConnectionClosedError:
|
||||
var connectionClosedError *peer.ConnectionClosedError
|
||||
switch {
|
||||
case errors.As(err, &connectionClosedError):
|
||||
// conn has been forced to close, so we exit the loop
|
||||
return
|
||||
default:
|
||||
@@ -910,7 +1039,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
}
|
||||
@@ -957,7 +1085,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
go func() {
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(func(msg *sProto.Message) error {
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -973,8 +1101,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -997,8 +1123,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -1021,7 +1145,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteCandidate(candidate)
|
||||
|
||||
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
}
|
||||
|
||||
@@ -1031,7 +1156,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
// happens if signal is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.cancel()
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -1092,13 +1217,20 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if err := e.wgProxyFactory.Free(); err != nil {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
if e.wgProxyFactory != nil {
|
||||
if err := e.wgProxyFactory.Free(); err != nil {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
@@ -1115,10 +1247,6 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Reset()
|
||||
if err != nil {
|
||||
@@ -1132,7 +1260,8 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
info := system.GetInfo(e.ctx)
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1161,7 +1290,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
||||
}
|
||||
|
||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
@@ -1188,14 +1317,21 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
|
||||
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||
e.ctx,
|
||||
e.wgInterface,
|
||||
e.mobileDep.HostDNSAddresses,
|
||||
*dnsConfig,
|
||||
e.mobileDep.NetworkChangeListener,
|
||||
e.statusRecorder,
|
||||
)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
return routes, dnsServer, nil
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||
return nil, dnsServer, nil
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1203,6 +1339,31 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the current routes from the route map
|
||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
return maps.Clone(e.clientRoutes)
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
||||
for id, v := range e.clientRoutes {
|
||||
routes[id.NetID()] = v
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetRouteManager returns the route manager
|
||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
@@ -1303,3 +1464,72 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
return
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,10 +17,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -55,9 +58,9 @@ var (
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH on Windows")
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
@@ -70,12 +73,12 @@ func TestEngine_SSH(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -171,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//time.Sleep(250 * time.Millisecond)
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
@@ -209,16 +212,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@@ -227,6 +230,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -390,7 +394,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
@@ -405,7 +409,8 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -563,12 +568,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -576,10 +582,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}{}
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
input.inputSerial = updateSerial
|
||||
input.inputRoutes = newRoutes
|
||||
return testCase.inputErr
|
||||
return nil, nil, testCase.inputErr
|
||||
},
|
||||
}
|
||||
|
||||
@@ -596,8 +602,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -732,17 +738,19 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
return nil, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -803,13 +811,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1001,10 +1009,14 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
@@ -1012,7 +1024,9 @@ func startSignal() (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
@@ -1023,7 +1037,9 @@ func startSignal() (*grpc.Server, string, error) {
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
@@ -1040,22 +1056,25 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := server.NewStoreFromJson(config.Datadir, nil)
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
||||
}
|
||||
|
||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
return err
|
||||
|
||||
21
client/internal/networkmonitor/monitor.go
Normal file
21
client/internal/networkmonitor/monitor.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrStopped = errors.New("monitor has been stopped")
|
||||
|
||||
// NetworkMonitor watches for changes in network configuration.
|
||||
type NetworkMonitor struct {
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new network monitor.
|
||||
func New() *NetworkMonitor {
|
||||
return &NetworkMonitor{}
|
||||
}
|
||||
95
client/internal/networkmonitor/monitor_bsd.go
Normal file
95
client/internal/networkmonitor/monitor_bsd.go
Normal file
@@ -0,0 +1,95 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !route.Dst.Addr().IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
82
client/internal/networkmonitor/monitor_generic.go
Normal file
82
client/internal/networkmonitor/monitor_generic.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
nw.mu.Lock()
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
nw.mu.Unlock()
|
||||
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel != nil {
|
||||
nw.cancel()
|
||||
nw.wg.Wait()
|
||||
}
|
||||
}
|
||||
57
client/internal/networkmonitor/monitor_linux.go
Normal file
57
client/internal/networkmonitor/monitor_linux.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
routeChan := make(chan netlink.RouteUpdate)
|
||||
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Network monitor: started")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN {
|
||||
continue
|
||||
}
|
||||
switch route.Type {
|
||||
// triggered on added/replaced routes
|
||||
case syscall.RTM_NEWROUTE:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
12
client/internal/networkmonitor/monitor_mobile.go
Normal file
12
client/internal/networkmonitor/monitor_mobile.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build ios || android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import "context"
|
||||
|
||||
func (nw *NetworkMonitor) Start(context.Context, func()) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
}
|
||||
245
client/internal/networkmonitor/monitor_windows.go
Normal file
245
client/internal/networkmonitor/monitor_windows.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
const (
|
||||
unreachable = 0
|
||||
incomplete = 1
|
||||
probe = 2
|
||||
delay = 3
|
||||
stale = 4
|
||||
reachable = 5
|
||||
permanent = 6
|
||||
tbd = 7
|
||||
)
|
||||
|
||||
const interval = 10 * time.Second
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
var neighborv4, neighborv6 *systemops.Neighbor
|
||||
{
|
||||
initialNeighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
|
||||
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
|
||||
}
|
||||
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
|
||||
if n, ok := initialNeighbors[nexthop.IP]; ok &&
|
||||
n.State != unreachable &&
|
||||
n.State != incomplete &&
|
||||
n.State != tbd {
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
nexthopv6 systemops.Nexthop,
|
||||
neighborv6 *systemops.Neighbor,
|
||||
) bool {
|
||||
neighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
log.Errorf("network monitor: error fetching current neighbors: %v", err)
|
||||
return false
|
||||
}
|
||||
if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
|
||||
return true
|
||||
}
|
||||
|
||||
routes, err := getRoutes()
|
||||
if err != nil {
|
||||
log.Errorf("network monitor: error fetching current routes: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
|
||||
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
|
||||
|
||||
if !foundMatchingRoute {
|
||||
logRouteChange(nexthop.IP, intf)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
if ip.Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Destination == unspec {
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultRoutes, foundMatchingRoute
|
||||
}
|
||||
|
||||
func formatRouteInfo(r systemops.Route) string {
|
||||
newIntf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
|
||||
}
|
||||
|
||||
func logRouteChange(ip netip.Addr, intf *net.Interface) {
|
||||
oldIntf := "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
if neighbor == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State == unreachable || n.State == incomplete {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||
log.Infof(
|
||||
"network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
|
||||
neighbor.IPAddress,
|
||||
neighbor.LinkLayerAddress,
|
||||
neighbor.InterfaceAlias,
|
||||
neighbor.InterfaceIndex,
|
||||
n.InterfaceAlias,
|
||||
n.InterfaceIndex,
|
||||
stateFromInt(n.State),
|
||||
)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
entries, err := systemops.GetNeighbors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
|
||||
for _, entry := range entries {
|
||||
neighbours[entry.IPAddress] = entry
|
||||
}
|
||||
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() ([]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func stateFromInt(state uint8) string {
|
||||
switch state {
|
||||
case unreachable:
|
||||
return "unreachable"
|
||||
case incomplete:
|
||||
return "incomplete"
|
||||
case probe:
|
||||
return "probe"
|
||||
case delay:
|
||||
return "delay"
|
||||
case stale:
|
||||
return "stale"
|
||||
case reachable:
|
||||
return "reachable"
|
||||
case permanent:
|
||||
return "permanent"
|
||||
case tbd:
|
||||
return "tbd"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func compareIntf(a, b *net.Interface) int {
|
||||
if a == nil && b == nil {
|
||||
return 0
|
||||
}
|
||||
if a == nil {
|
||||
return -1
|
||||
}
|
||||
if b == nil {
|
||||
return 1
|
||||
}
|
||||
return a.Index - b.Index
|
||||
}
|
||||
@@ -18,14 +18,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
iceKeepAliveDefault = 4 * time.Second
|
||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
)
|
||||
@@ -65,9 +68,6 @@ type ConnConfig struct {
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
|
||||
UserspaceBind bool
|
||||
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
@@ -127,23 +127,13 @@ type Conn struct {
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxy wgproxy.Proxy
|
||||
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
sentExtraSrflx bool
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
type meta struct {
|
||||
protoSupport signal.FeaturesSupport
|
||||
}
|
||||
|
||||
// ModeMessage represents a connection mode chosen by the peer
|
||||
type ModeMessage struct {
|
||||
// Direct indicates that it decided to use a direct connection
|
||||
Direct bool
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
}
|
||||
|
||||
// GetConf returns the connection config
|
||||
@@ -172,7 +162,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
@@ -193,20 +182,22 @@ func (conn *Conn) reCreateAgent() error {
|
||||
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: conn.config.StunTurn,
|
||||
CandidateTypes: conn.candidateTypes(),
|
||||
FailedTimeout: &failedTimeout,
|
||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||
UDPMux: conn.config.UDPMux,
|
||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||
KeepaliveInterval: &iceKeepAlive,
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: conn.config.StunTurn,
|
||||
CandidateTypes: conn.candidateTypes(),
|
||||
FailedTimeout: &failedTimeout,
|
||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||
UDPMux: conn.config.UDPMux,
|
||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||
KeepaliveInterval: &iceKeepAlive,
|
||||
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||
}
|
||||
|
||||
if conn.config.DisableIPv6Discovery {
|
||||
@@ -214,7 +205,6 @@ func (conn *Conn) reCreateAgent() error {
|
||||
}
|
||||
|
||||
conn.agent, err = ice.NewAgent(agentConfig)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,6 +224,17 @@ func (conn *Conn) reCreateAgent() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
|
||||
err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
|
||||
if err != nil {
|
||||
log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed setting binding response callback: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -251,7 +252,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
|
||||
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
||||
// Blocks until connection has been closed or connection timeout.
|
||||
// ConnStatus will be set accordingly
|
||||
func (conn *Conn) Open() error {
|
||||
func (conn *Conn) Open(ctx context.Context) error {
|
||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||
|
||||
peerState := State{
|
||||
@@ -259,6 +260,7 @@ func (conn *Conn) Open() error {
|
||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: conn.status,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -310,7 +312,7 @@ func (conn *Conn) Open() error {
|
||||
// at this point we received offer/answer and we are ready to gather candidates
|
||||
conn.mu.Lock()
|
||||
conn.status = StatusConnecting
|
||||
conn.ctx, conn.notifyDisconnected = context.WithCancel(context.Background())
|
||||
conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
|
||||
defer conn.notifyDisconnected()
|
||||
conn.mu.Unlock()
|
||||
|
||||
@@ -318,6 +320,7 @@ func (conn *Conn) Open() error {
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -326,7 +329,7 @@ func (conn *Conn) Open() error {
|
||||
|
||||
err = conn.agent.GatherCandidates()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("gather candidates: %v", err)
|
||||
}
|
||||
|
||||
// will block until connection succeeded
|
||||
@@ -343,11 +346,12 @@ func (conn *Conn) Open() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamically set remote WireGuard port is other side specified a different one from the default one
|
||||
// dynamically set remote WireGuard port if other side specified a different one from the default one
|
||||
remoteWgPort := iface.DefaultWgPort
|
||||
if remoteOfferAnswer.WgListenPort != 0 {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||
remoteOfferAnswer.RosenpassAddr)
|
||||
@@ -372,6 +376,14 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
||||
conn.mu.Lock()
|
||||
@@ -385,7 +397,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
var endpoint net.Addr
|
||||
if isRelayCandidate(pair.Local) {
|
||||
log.Debugf("setup relay connection")
|
||||
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
||||
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -397,13 +409,23 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connID = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
|
||||
log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
if conn.wgProxy != nil {
|
||||
_ = conn.wgProxy.CloseConn()
|
||||
if err := conn.wgProxy.CloseConn(); err != nil {
|
||||
log.Warnf("Failed to close turn connection: %v", err)
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("update peer: %w", err)
|
||||
}
|
||||
|
||||
conn.status = StatusConnected
|
||||
@@ -419,9 +441,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Direct: !isRelayCandidate(pair.Local),
|
||||
RosenpassEnabled: rosenpassEnabled,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
@@ -437,6 +460,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "ios" {
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
|
||||
}
|
||||
@@ -488,6 +515,15 @@ func (conn *Conn) cleanup() error {
|
||||
// todo: is it problem if we try to remove a peer what is never existed?
|
||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
|
||||
if conn.connID != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connID); err != nil {
|
||||
log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
conn.connID = ""
|
||||
|
||||
if conn.notifyDisconnected != nil {
|
||||
conn.notifyDisconnected()
|
||||
conn.notifyDisconnected = nil
|
||||
@@ -503,6 +539,7 @@ func (conn *Conn) cleanup() error {
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -557,40 +594,39 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
|
||||
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
||||
// and then signals them to the remote peer
|
||||
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
||||
if candidate != nil {
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
}
|
||||
}()
|
||||
// nil means candidate gathering has been ended
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !conn.shouldSendExtraSrflxCandidate(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
|
||||
go func() {
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
||||
@@ -672,7 +708,7 @@ func (conn *Conn) Close() error {
|
||||
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
|
||||
// engine adds a new Conn for 4 and 5
|
||||
// therefore peer 4 has 2 Conn objects
|
||||
log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key)
|
||||
log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
|
||||
return NewConnectionAlreadyClosed(conn.config.Key)
|
||||
}
|
||||
}
|
||||
@@ -715,7 +751,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
|
||||
go func() {
|
||||
conn.mu.Lock()
|
||||
@@ -737,8 +773,21 @@ func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||
protoSupport := signal.ParseFeaturesSupported(support)
|
||||
conn.meta.protoSupport = protoSupport
|
||||
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -35,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -50,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -87,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -123,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@@ -153,7 +154,7 @@ func TestConn_Status(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
||||
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
||||
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
||||
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
||||
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
||||
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
|
||||
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
||||
)
|
||||
|
||||
func iceKeepAlive() time.Duration {
|
||||
@@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration {
|
||||
return iceKeepAliveDefault
|
||||
}
|
||||
|
||||
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
||||
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
||||
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
||||
if err != nil {
|
||||
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
||||
@@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration {
|
||||
return iceDisconnectedTimeoutDefault
|
||||
}
|
||||
|
||||
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
||||
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
||||
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
||||
if err != nil {
|
||||
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
||||
@@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration {
|
||||
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||
}
|
||||
|
||||
func iceRelayAcceptanceMinWait() time.Duration {
|
||||
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
|
||||
if iceRelayAcceptanceMinWaitEnv == "" {
|
||||
return iceRelayAcceptanceMinWaitDefault
|
||||
}
|
||||
|
||||
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
|
||||
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
|
||||
if err != nil {
|
||||
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
|
||||
return iceRelayAcceptanceMinWaitDefault
|
||||
}
|
||||
|
||||
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||
}
|
||||
|
||||
func hasICEForceRelayConn() bool {
|
||||
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
||||
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
||||
|
||||
@@ -2,15 +2,22 @@ package peer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// State contains the latest state of a peer
|
||||
type State struct {
|
||||
Mux *sync.RWMutex
|
||||
IP string
|
||||
PubKey string
|
||||
FQDN string
|
||||
@@ -25,7 +32,40 @@ type State struct {
|
||||
LastWireguardHandshake time.Time
|
||||
BytesTx int64
|
||||
BytesRx int64
|
||||
Latency time.Duration
|
||||
RosenpassEnabled bool
|
||||
routes map[string]struct{}
|
||||
}
|
||||
|
||||
// AddRoute add a single route to routes map
|
||||
func (s *State) AddRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
if s.routes == nil {
|
||||
s.routes = make(map[string]struct{})
|
||||
}
|
||||
s.routes[network] = struct{}{}
|
||||
}
|
||||
|
||||
// SetRoutes set state routes
|
||||
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
s.routes = routes
|
||||
}
|
||||
|
||||
// DeleteRoute removes a route from the network amp
|
||||
func (s *State) DeleteRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
delete(s.routes, network)
|
||||
}
|
||||
|
||||
// GetRoutes return routes map
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return s.routes
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -34,6 +74,7 @@ type LocalPeerState struct {
|
||||
PubKey string
|
||||
KernelInterface bool
|
||||
FQDN string
|
||||
Routes map[string]struct{}
|
||||
}
|
||||
|
||||
// SignalState contains the latest state of a signal connection
|
||||
@@ -56,6 +97,16 @@ type RosenpassState struct {
|
||||
Permissive bool
|
||||
}
|
||||
|
||||
// NSGroupState represents the status of a DNS server group, including associated domains,
|
||||
// whether it's enabled, and the last error message encountered during probing.
|
||||
type NSGroupState struct {
|
||||
ID string
|
||||
Servers []string
|
||||
Domains []string
|
||||
Enabled bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
type FullStatus struct {
|
||||
Peers []State
|
||||
@@ -64,25 +115,28 @@ type FullStatus struct {
|
||||
LocalPeerState LocalPeerState
|
||||
RosenpassState RosenpassState
|
||||
Relays []relay.ProbeResult
|
||||
NSGroupStates []NSGroupState
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@@ -93,11 +147,12 @@ type Status struct {
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +180,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
||||
PubKey: peerPubKey,
|
||||
ConnStatus: StatusDisconnected,
|
||||
FQDN: fqdn,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
@@ -137,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, errors.New("peer not found")
|
||||
return State{}, iface.ErrPeerNotFound
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@@ -171,6 +227,10 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
@@ -275,6 +335,13 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
return ch
|
||||
}
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.localPeer
|
||||
}
|
||||
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Lock()
|
||||
@@ -361,6 +428,24 @@ func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.resolvedDomainsStates[domain] = prefixes
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
delete(d.resolvedDomainsStates, domain)
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
@@ -376,6 +461,39 @@ func (d *Status) GetManagementState() ManagementState {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
if latency <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
peerState, ok := d.peers[pubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
peerState.Latency = latency
|
||||
d.peers[pubKey] = peerState
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
// if peer is connected to the management then login is not expired
|
||||
if d.managementState {
|
||||
return false
|
||||
}
|
||||
|
||||
s, ok := gstatus.FromError(d.managementError)
|
||||
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
return SignalState{
|
||||
d.signalAddress,
|
||||
@@ -388,6 +506,16 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
@@ -399,6 +527,7 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
LocalPeerState: d.localPeer,
|
||||
Relays: d.GetRelayStates(),
|
||||
RosenpassState: d.GetRosenpassState(),
|
||||
NSGroupStates: d.GetDNSStates(),
|
||||
}
|
||||
|
||||
for _, status := range d.peers {
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/turn/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
@@ -27,7 +30,15 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := stun.DialURI(uri, &stun.DialConfig{})
|
||||
net, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := stun.DialURI(uri, &stun.DialConfig{
|
||||
Net: net,
|
||||
})
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("dial: %w", err)
|
||||
return
|
||||
@@ -85,14 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
switch uri.Proto {
|
||||
case stun.ProtoTypeUDP:
|
||||
var err error
|
||||
conn, err = net.ListenPacket("udp", "")
|
||||
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("listen: %w", err)
|
||||
return
|
||||
}
|
||||
case stun.ProtoTypeTCP:
|
||||
dialer := net.Dialer{}
|
||||
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
|
||||
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("dial: %w", err)
|
||||
return
|
||||
@@ -109,12 +119,18 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
}
|
||||
cfg := &turn.ClientConfig{
|
||||
STUNServerAddr: turnServerAddr,
|
||||
TURNServerAddr: turnServerAddr,
|
||||
Conn: conn,
|
||||
Username: uri.Username,
|
||||
Password: uri.Password,
|
||||
Net: net,
|
||||
}
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
@@ -154,7 +170,7 @@ func ProbeAll(
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range relays {
|
||||
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
@@ -3,21 +3,25 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const minRangeBits = 7
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
direct bool
|
||||
latency time.Duration
|
||||
}
|
||||
|
||||
type routesUpdate struct {
|
||||
@@ -25,38 +29,48 @@ type routesUpdate struct {
|
||||
routes []*route.Route
|
||||
}
|
||||
|
||||
// RouteHandler defines the interface for handling routes
|
||||
type RouteHandler interface {
|
||||
String() string
|
||||
AddRoute(ctx context.Context) error
|
||||
RemoveRoute() error
|
||||
AddAllowedIPs(peerKey string) error
|
||||
RemoveAllowedIPs() error
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[string]*route.Route
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
chosenRoute *route.Route
|
||||
network netip.Prefix
|
||||
currentChosen *route.Route
|
||||
handler RouteHandler
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
routes: make(map[string]*route.Route),
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
network: network,
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
routePeerStatuses := make(map[string]routerPeerStatus)
|
||||
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
||||
for _, r := range c.routes {
|
||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
@@ -67,22 +81,37 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||
relayed: peerStatus.Relayed,
|
||||
direct: peerStatus.Direct,
|
||||
latency: peerStatus.Latency,
|
||||
}
|
||||
}
|
||||
return routePeerStatuses
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||
chosen := ""
|
||||
chosenScore := 0
|
||||
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
||||
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
||||
// preference for non-relayed and direct connections.
|
||||
//
|
||||
// It follows these prioritization rules:
|
||||
// * Connected peers: Only routes with connected peers are considered.
|
||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
chosen := route.ID("")
|
||||
chosenScore := float64(0)
|
||||
currScore := float64(0)
|
||||
|
||||
currID := ""
|
||||
if c.chosenRoute != nil {
|
||||
currID = c.chosenRoute.ID
|
||||
currID := route.ID("")
|
||||
if c.currentChosen != nil {
|
||||
currID = c.currentChosen.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
tempScore := 0
|
||||
tempScore := float64(0)
|
||||
peerStatus, found := routePeerStatuses[r.ID]
|
||||
if !found || !peerStatus.connected {
|
||||
continue
|
||||
@@ -90,9 +119,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
|
||||
if r.Metric < route.MaxMetric {
|
||||
metricDiff := route.MaxMetric - r.Metric
|
||||
tempScore = metricDiff * 10
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
}
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
tempScore++
|
||||
}
|
||||
@@ -101,7 +139,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
tempScore++
|
||||
}
|
||||
|
||||
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
|
||||
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
|
||||
chosen = r.ID
|
||||
chosenScore = tempScore
|
||||
}
|
||||
@@ -110,18 +148,31 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
chosen = r.ID
|
||||
chosenScore = tempScore
|
||||
}
|
||||
|
||||
if r.ID == currID {
|
||||
currScore = tempScore
|
||||
}
|
||||
}
|
||||
|
||||
if chosen == "" {
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
for _, r := range c.routes {
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||
|
||||
} else if chosen != currID {
|
||||
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
||||
case chosen != currID:
|
||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
return currID
|
||||
}
|
||||
var p string
|
||||
if rt := c.routes[chosen]; rt != nil {
|
||||
p = rt.Peer
|
||||
}
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
||||
}
|
||||
|
||||
return chosen
|
||||
@@ -155,85 +206,103 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||
c.network, err)
|
||||
}
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
|
||||
var err error
|
||||
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
if chosen == "" {
|
||||
err = c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
return err
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if newChosenID == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.chosenRoute = nil
|
||||
c.currentChosen = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
}
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||
c.currentChosen.IsEqual(c.routes[newChosenID]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil {
|
||||
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
if c.currentChosen == nil {
|
||||
// If they were not previously assigned to another peer, add routes to the system first
|
||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
|
||||
c.chosenRoute = c.routes[chosen]
|
||||
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
||||
if err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
c.currentChosen = c.routes[newChosenID]
|
||||
|
||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.addStateRoute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
@@ -241,7 +310,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
}
|
||||
|
||||
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
||||
updateMap := make(map[string]*route.Route)
|
||||
updateMap := make(map[route.ID]*route.Route)
|
||||
|
||||
for _, r := range update.routes {
|
||||
updateMap[r.ID] = r
|
||||
@@ -264,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("received a new client network route update for %s", c.network)
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
@@ -289,10 +357,17 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
}
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package routemanager
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -11,90 +13,90 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
statuses map[string]routerPeerStatus
|
||||
expectedRouteID string
|
||||
currentRoute *route.Route
|
||||
existingRoutes map[string]*route.Route
|
||||
statuses map[route.ID]routerPeerStatus
|
||||
expectedRouteID route.ID
|
||||
currentRoute route.ID
|
||||
existingRoutes map[route.ID]*route.Route
|
||||
}{
|
||||
{
|
||||
name: "one route",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "one connected routes with relayed and direct",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
direct: true,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "one connected routes with relayed and no direct",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
direct: false,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "no connected peers",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: false,
|
||||
relayed: false,
|
||||
direct: false,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "",
|
||||
},
|
||||
{
|
||||
name: "multiple connected peers with different metrics",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
@@ -106,7 +108,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
direct: true,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: 9000,
|
||||
@@ -118,12 +120,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "multiple connected peers with one relayed",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
@@ -135,7 +137,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
direct: true,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
@@ -147,12 +149,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "multiple connected peers with one direct",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
@@ -164,7 +166,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
direct: false,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
@@ -176,18 +178,172 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "multiple connected peers with different latencies",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
latency: 300 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
{
|
||||
name: "should ignore routes with latency 0",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
{
|
||||
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 15 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current route with bad score should be changed to route with better score",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 200 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
{
|
||||
name: "current chosen route doesn't exist anymore",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 20 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "routeDoesntExistAnymore",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
ID: "routeDoesntExistAnymore",
|
||||
}
|
||||
if tc.currentRoute != "" {
|
||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||
}
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
routes: tc.existingRoutes,
|
||||
chosenRoute: tc.currentRoute,
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
}
|
||||
|
||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||
|
||||
378
client/internal/routemanager/dynamic/route.go
Normal file
378
client/internal/routemanager/dynamic/route.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInterval = time.Minute
|
||||
|
||||
minInterval = 2 * time.Second
|
||||
failureInterval = 5 * time.Second
|
||||
|
||||
addAllowedIP = "add allowed IP %s: %w"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type resolveResult struct {
|
||||
domain domain.Domain
|
||||
prefix netip.Prefix
|
||||
err error
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
interval time.Duration
|
||||
dynamicDomains domainMap
|
||||
mu sync.Mutex
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
ctx, r.cancel = context.WithCancel(ctx)
|
||||
|
||||
go r.startResolver(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
|
||||
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
|
||||
func (r *Route) RemoveRoute() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range r.dynamicDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
r.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
r.dynamicDomains = domainMap{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
r.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) startResolver(ctx context.Context) {
|
||||
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
|
||||
|
||||
interval := r.interval
|
||||
if interval < minInterval {
|
||||
interval = minInterval
|
||||
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
// Use a lower ticker interval if the update fails
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
} else if interval > failureInterval {
|
||||
// Reset to the original interval if the update succeeds
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) resolveDomains() (domainMap, error) {
|
||||
results := make(chan resolveResult)
|
||||
go r.resolve(results)
|
||||
|
||||
resolved := domainMap{}
|
||||
var merr *multierror.Error
|
||||
|
||||
for result := range results {
|
||||
if result.err != nil {
|
||||
merr = multierror.Append(merr, result.err)
|
||||
} else {
|
||||
resolved[result.domain] = append(resolved[result.domain], result.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return resolved, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) resolve(results chan resolveResult) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, d := range r.route.Domains {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
|
||||
return
|
||||
}
|
||||
results <- resolveResult{domain: domain, prefix: prefix}
|
||||
}
|
||||
}(d)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}
|
||||
|
||||
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for domain, newPrefixes := range newDomains {
|
||||
oldPrefixes := r.dynamicDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
addedPrefixes, err := r.addRoutes(domain, toAdd)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(addedPrefixes) > 0 {
|
||||
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
removedPrefixes, err := r.removeRoutes(toRemove)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(removedPrefixes) > 0 {
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
var addedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return addedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
if r.route.KeepRoute {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var removedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
removedPrefixes = append(removedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return removedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
return fmt.Errorf(addAllowedIP, prefix, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
|
||||
prefixSet := make(map[netip.Prefix]struct{})
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
for _, prefix := range removedPrefixes {
|
||||
delete(prefixSet, prefix)
|
||||
}
|
||||
for _, prefix := range addedPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
var combinedPrefixes []netip.Prefix
|
||||
for prefix := range prefixSet {
|
||||
combinedPrefixes = append(combinedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return combinedPrefixes
|
||||
}
|
||||
@@ -2,22 +2,36 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
@@ -26,29 +40,71 @@ type Manager interface {
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface *iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
initialRoutes []*route.Route,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
sysOps := systemops.NewSysOps(wgInterface)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (any, error) {
|
||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ any) error {
|
||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
)
|
||||
|
||||
dm.allowedIPsRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, peerKey string) (string, error) {
|
||||
// save peerKey to use it in the remove function
|
||||
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
||||
},
|
||||
func(prefix netip.Prefix, peerKey string) error {
|
||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
dm.notifier.setInitialClientRoutes(cr)
|
||||
@@ -56,9 +112,31 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
return dm
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
log.Info("Routing setup complete")
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -71,32 +149,53 @@ func (m *DefaultManager) Stop() {
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
if m.routeRefCounter != nil {
|
||||
if err := m.routeRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing route ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
if m.allowedIPsRefCounter != nil {
|
||||
if err := m.allowedIPsRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
}
|
||||
|
||||
m.ctx = nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
return nil, nil, m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
m.notifier.onNewRoutes(newClientRoutesIDMap)
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.onNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,24 +206,62 @@ func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeL
|
||||
|
||||
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
||||
func (m *DefaultManager) InitialRouteRange() []string {
|
||||
return m.notifier.initialRouteRanges()
|
||||
return m.notifier.getInitialRouteRanges()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
// GetRouteSelector returns the route selector
|
||||
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return m.routeSelector
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the client routes
|
||||
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
||||
return m.clientNetworks
|
||||
}
|
||||
|
||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||
func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
networks = m.routeSelector.FilterSelected(networks)
|
||||
|
||||
m.notifier.onNewRoutes(networks)
|
||||
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
if _, found := m.clientNetworks[id]; found {
|
||||
// don't touch existing client network watchers
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
}
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.cancel()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
@@ -136,15 +273,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
|
||||
newClientRoutesIDMap := make(route.HAMap)
|
||||
newServerRoutesMap := make(map[route.ID]*route.Route)
|
||||
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[networkID] = true
|
||||
ownNetworkIDs[haID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
@@ -155,16 +292,12 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < minRangeBits {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if !ownNetworkIDs[haID] {
|
||||
if !isRouteSupported(newRoute) {
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,10 +305,44 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
||||
}
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifiesRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0)
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0, len(crMap))
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func isRouteSupported(route *route.Route) bool {
|
||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||
return true
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if route.Network.Bits() <= vars.MinRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), route.Network)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
||||
func resolveURLsToIPs(urls []string) []net.IP {
|
||||
var ips []net.IP
|
||||
for _, rawurl := range urls {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
||||
continue
|
||||
}
|
||||
ipAddrs, err := net.LookupIP(u.Hostname())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ipAddrs...)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
@@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1"
|
||||
|
||||
func TestManagerUpdateRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
clientNetworkWatchersExpectedAllowed int
|
||||
}{
|
||||
{
|
||||
name: "Should create 2 client networks",
|
||||
@@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedAllowed: 1,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
@@ -405,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -414,7 +416,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop()
|
||||
|
||||
if testCase.removeSrvRouter {
|
||||
@@ -422,14 +428,18 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||
|
||||
@@ -6,14 +6,22 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||
StopFunc func()
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||
@@ -22,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
|
||||
}
|
||||
|
||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
if m.UpdateRoutesFunc != nil {
|
||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||
}
|
||||
return fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||
if m.TriggerSelectionFunc != nil {
|
||||
m.TriggerSelectionFunc(networks)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
|
||||
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
if m.GetRouteSelectorFunc != nil {
|
||||
return m.GetRouteSelectorFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start mock implementation of Start from Manager interface
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -10,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
initialRouteRangers []string
|
||||
routeRangers []string
|
||||
initialRouteRanges []string
|
||||
routeRanges []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -33,10 +34,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets = append(nets, r.Network.String())
|
||||
}
|
||||
sort.Strings(nets)
|
||||
n.initialRouteRangers = nets
|
||||
n.initialRouteRanges = nets
|
||||
}
|
||||
|
||||
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
||||
func (n *notifier) onNewRoutes(idMap route.HAMap) {
|
||||
newNets := make([]string, 0)
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
@@ -45,11 +46,18 @@ func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
if !n.hasDiff(n.initialRouteRangers, newNets) {
|
||||
return
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n.routeRangers = newNets
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
@@ -62,7 +70,7 @@ func (n *notifier) notify() {
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.routeRangers, ","))
|
||||
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
@@ -78,6 +86,20 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *notifier) initialRouteRanges() []string {
|
||||
return n.initialRouteRangers
|
||||
func (n *notifier) getInitialRouteRanges() []string {
|
||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
||||
}
|
||||
|
||||
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
||||
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||
ranges := inputRanges
|
||||
for _, r := range inputRanges {
|
||||
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
||||
// to ensure that all traffic is managed by the tunnel interface on android
|
||||
if r == "0.0.0.0/0" {
|
||||
ranges = append(ranges, "::/0")
|
||||
break
|
||||
}
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
|
||||
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package refcounter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
||||
var ErrIgnore = errors.New("ignore")
|
||||
|
||||
type Ref[O any] struct {
|
||||
Count int
|
||||
Out O
|
||||
}
|
||||
|
||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
||||
|
||||
type Counter[I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for prefixes
|
||||
refCountMap map[netip.Prefix]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the prefixes associated with an ID for removal
|
||||
idMap map[string][]netip.Prefix
|
||||
idMu sync.Mutex
|
||||
add AddFunc[I, O]
|
||||
remove RemoveFunc[I, O]
|
||||
}
|
||||
|
||||
// New creates a new Counter instance
|
||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
||||
return &Counter[I, O]{
|
||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
||||
idMap: map[string][]netip.Prefix{},
|
||||
add: add,
|
||||
remove: remove,
|
||||
}
|
||||
}
|
||||
|
||||
// Increment increments the reference count for the given prefix.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
|
||||
// Call AddFunc only if it's a new prefix
|
||||
if ref.Count == 0 {
|
||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
||||
out, err := rm.add(prefix, in)
|
||||
|
||||
if errors.Is(err, ErrIgnore) {
|
||||
return ref, nil
|
||||
}
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.Out = out
|
||||
}
|
||||
|
||||
ref.Count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(prefix, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Decrement decrements the reference count for the given prefix.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[prefix]
|
||||
if !ok {
|
||||
log.Tracef("No reference found for prefix %s", prefix)
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
if ref.Count == 1 {
|
||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.Count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
delete(rm.idMap, id)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
||||
func (rm *Counter[I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Tracef("Removing for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
||||
|
||||
rm.idMap = map[string][]netip.Prefix{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
7
client/internal/routemanager/refcounter/types.go
Normal file
7
client/internal/routemanager/refcounter/types.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package refcounter
|
||||
|
||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||
type RouteRefCounter = Counter[any, any]
|
||||
|
||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||
type AllowedIPsRefCounter = Counter[string, string]
|
||||
@@ -3,7 +3,7 @@ package routemanager
|
||||
import "github.com/netbirdio/netbird/route"
|
||||
|
||||
type serverRouter interface {
|
||||
updateRoutes(map[string]*route.Route) error
|
||||
updateRoutes(map[route.ID]*route.Route) error
|
||||
removeFromServerNetwork(*route.Route) error
|
||||
cleanUp()
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
|
||||
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
@@ -4,35 +4,40 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type defaultServerRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[string]*route.Route
|
||||
firewall firewall.Manager
|
||||
wgInterface *iface.WGIface
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
firewall firewall.Manager
|
||||
wgInterface *iface.WGIface
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||
return &defaultServerRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[string]*route.Route),
|
||||
firewall: firewall,
|
||||
wgInterface: wgInterface,
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
firewall: firewall,
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||
serverRoutesToRemove := make([]route.ID, 0)
|
||||
|
||||
for routeID := range m.routes {
|
||||
update, found := routesMap[routeID]
|
||||
@@ -45,7 +50,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
oldRoute := m.routes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.routes, routeID)
|
||||
@@ -59,14 +64,14 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.routes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
err := systemops.EnableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,16 +83,28 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
log.Infof("Not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
delete(state.Routes, route.Network.String())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -95,16 +112,37 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
log.Infof("Not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.InsertRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -113,19 +151,45 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
|
||||
routerPair, err := routeToRouterPair(r)
|
||||
if err != nil {
|
||||
log.Warnf("failed to remove clean up route: %s", r.ID)
|
||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
state.Routes = nil
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return firewall.RouterPair{
|
||||
ID: route.ID,
|
||||
Source: parsed.String(),
|
||||
Destination: route.Network.Masked().String(),
|
||||
Masquerade: route.Masquerade,
|
||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
// TODO: add ipv6
|
||||
source := getDefaultPrefix(route.Network)
|
||||
|
||||
destination := route.Network.Masked().String()
|
||||
if route.IsDynamic() {
|
||||
// TODO: add ipv6
|
||||
destination = "0.0.0.0/0"
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
ID: string(route.ID),
|
||||
Source: source.String(),
|
||||
Destination: destination,
|
||||
Masquerade: route.Masquerade,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
57
client/internal/routemanager/static/route.go
Normal file
57
client/internal/routemanager/static/route.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
// Route route methods
|
||||
func (r *Route) String() string {
|
||||
return r.route.Network.String()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) RemoveRoute() error {
|
||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
r.route.Network,
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// go:build !android
|
||||
package sysctl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := Set(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = Set(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := Set(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
// Cleanup resets sysctl settings to their original values.
|
||||
func Cleanup(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := Set(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build darwin || dragonfly || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build: freebsd
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
|
||||
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
27
client/internal/routemanager/systemops/systemops.go
Normal file
27
client/internal/routemanager/systemops/systemops.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
IP netip.Addr
|
||||
Intf *net.Interface
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user