mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
202 Commits
feature/ev
...
wg_bind_pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
940367e1c6 | ||
|
|
c80cb22cc0 | ||
|
|
d0e51dfc11 | ||
|
|
7e11842a5e | ||
|
|
1bf3e6feec | ||
|
|
c392fe44e4 | ||
|
|
fef24b4a4d | ||
|
|
bb1efe68aa | ||
|
|
bb147c2a7c | ||
|
|
4616bc5258 | ||
|
|
1803cf3678 | ||
|
|
9f35a7fb8d | ||
|
|
2eeed55c18 | ||
|
|
0343c5f239 | ||
|
|
251f2d7bc2 | ||
|
|
306e02d32b | ||
|
|
8375491708 | ||
|
|
e197b89ac3 | ||
|
|
6aba28ccb7 | ||
|
|
8f9826b207 | ||
|
|
0aad9169e9 | ||
|
|
1057cd211d | ||
|
|
32b345991a | ||
|
|
e903522f8c | ||
|
|
ea88ec6d27 | ||
|
|
2be1a82f4a | ||
|
|
fe1ea4a2d0 | ||
|
|
f14f34cf2b | ||
|
|
109481e26d | ||
|
|
18098e7a7d | ||
|
|
5993982cca | ||
|
|
86f9051a30 | ||
|
|
489892553a | ||
|
|
b05e30ac5a | ||
|
|
769388cd21 | ||
|
|
c54fb9643c | ||
|
|
5dc0ff42a5 | ||
|
|
45badd2c39 | ||
|
|
d3de035961 | ||
|
|
b2da0ae70f | ||
|
|
931c20c8fe | ||
|
|
2eaf4aa8d7 | ||
|
|
110067c00f | ||
|
|
32c96c15b8 | ||
|
|
ca1dc5ac88 | ||
|
|
ce775d59ae | ||
|
|
f273fe9f51 | ||
|
|
e08af7fcdf | ||
|
|
454240ca05 | ||
|
|
1343a3f00e | ||
|
|
2a79995706 | ||
|
|
e869882da1 | ||
|
|
6c8bb60632 | ||
|
|
4d7029d80c | ||
|
|
909f305728 | ||
|
|
5e2f66d591 | ||
|
|
a7519859bc | ||
|
|
9b000b89d5 | ||
|
|
5c1acdbf2f | ||
|
|
db3a9f0aa2 | ||
|
|
ecc4f8a10d | ||
|
|
03abdfa112 | ||
|
|
9746a7f61a | ||
|
|
4ec6d5d20b | ||
|
|
3bab745142 | ||
|
|
0ca3d27a80 | ||
|
|
c5942e6b33 | ||
|
|
726ffb5740 | ||
|
|
dfb7960cd4 | ||
|
|
ab0cf1b8aa | ||
|
|
8ebd6ce963 | ||
|
|
42ba0765c8 | ||
|
|
514403db37 | ||
|
|
488d338ce8 | ||
|
|
6a75ec4ab7 | ||
|
|
b66e984ddd | ||
|
|
c65a934107 | ||
|
|
55ebf93815 | ||
|
|
9e74f30d2f | ||
|
|
71d24e59e6 | ||
|
|
992cfe64e1 | ||
|
|
d1703479ff | ||
|
|
a27fe4326c | ||
|
|
e6292e3124 | ||
|
|
628b497e81 | ||
|
|
8f66dea11c | ||
|
|
de8608f99f | ||
|
|
9c5adfea2b | ||
|
|
8e4710763e | ||
|
|
82af60838e | ||
|
|
311b67fe5a | ||
|
|
94d39ab48c | ||
|
|
41a47be379 | ||
|
|
e30def175b | ||
|
|
e1ef091d45 | ||
|
|
511ba6d51f | ||
|
|
b852198f67 | ||
|
|
891ba277b1 | ||
|
|
747797271e | ||
|
|
628a201e31 | ||
|
|
731d3ae464 | ||
|
|
453643683d | ||
|
|
b8cab2882b | ||
|
|
6143b819c5 | ||
|
|
3b42d5e48a | ||
|
|
1d4dfa41d2 | ||
|
|
f8db5742b5 | ||
|
|
bc3cec23ec | ||
|
|
f03aadf064 | ||
|
|
292ee260ad | ||
|
|
2a1efbd0fd | ||
|
|
3bfa26b13b | ||
|
|
221934447e | ||
|
|
9ce8056b17 | ||
|
|
c65a5acab9 | ||
|
|
62de082961 | ||
|
|
c4d9b76634 | ||
|
|
b4bb5c6bb8 | ||
|
|
2b1965c941 | ||
|
|
83e7e30218 | ||
|
|
24310c63e2 | ||
|
|
ed4f90b6aa | ||
|
|
0e9610c5b2 | ||
|
|
ed470d7dbe | ||
|
|
cb8abacadd | ||
|
|
bcac5f7b32 | ||
|
|
95d87384ab | ||
|
|
ea3899e6d6 | ||
|
|
337d3edcc4 | ||
|
|
e914adb5cd | ||
|
|
2f2d45de9e | ||
|
|
b3f339c753 | ||
|
|
e0fc779f58 | ||
|
|
f64e0754ee | ||
|
|
fe22eb3b98 | ||
|
|
69be2a8071 | ||
|
|
1bda8fd563 | ||
|
|
1ab791e91b | ||
|
|
41948f7919 | ||
|
|
60f67076b0 | ||
|
|
c645171c40 | ||
|
|
f832c83a18 | ||
|
|
462a86cfcc | ||
|
|
8a130ec3f1 | ||
|
|
c26cd3b9fe | ||
|
|
9d7b515b26 | ||
|
|
f1f90807e4 | ||
|
|
5bb875a0fa | ||
|
|
9a88ed3cda | ||
|
|
8026c84c95 | ||
|
|
82059df324 | ||
|
|
23610db727 | ||
|
|
f984b8a091 | ||
|
|
4330bfd8ca | ||
|
|
5782496287 | ||
|
|
a0f2b5f591 | ||
|
|
0350faf75d | ||
|
|
9f951c8fb5 | ||
|
|
8276e0908a | ||
|
|
6539b591b6 | ||
|
|
014f1b841f | ||
|
|
b52afe8d42 | ||
|
|
f36869e97d | ||
|
|
78c6231c01 | ||
|
|
e75535d30b | ||
|
|
d8429c5c34 | ||
|
|
c3ed08c249 | ||
|
|
2f0b652dad | ||
|
|
d4214638a0 | ||
|
|
c962d29280 | ||
|
|
44af5be30f | ||
|
|
fe63a64b6e | ||
|
|
d31219ba89 | ||
|
|
756ce96da9 | ||
|
|
b64f5ffcb4 | ||
|
|
eb45310c8f | ||
|
|
d5dfed498b | ||
|
|
3fc89749c1 | ||
|
|
aecee361d0 | ||
|
|
f8273c3ce9 | ||
|
|
00a8092482 | ||
|
|
64dbd5fbfc | ||
|
|
b5217350cf | ||
|
|
3ec8274b8e | ||
|
|
494e56d1be | ||
|
|
9adadfade4 | ||
|
|
9e408b5bbc | ||
|
|
a0de9aa345 | ||
|
|
4406d50c18 | ||
|
|
5e3502bb83 | ||
|
|
793e4f1f29 | ||
|
|
dcf6533ed5 | ||
|
|
12ae2e93fc | ||
|
|
2bc3d88af3 | ||
|
|
afaf0660be | ||
|
|
d4d8c5f037 | ||
|
|
e5adc1eb23 | ||
|
|
44f612f121 | ||
|
|
f9dfafa9d9 | ||
|
|
ca62f6787a | ||
|
|
27f4993ce3 | ||
|
|
5c0b8a46f0 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -6,5 +6,6 @@
|
||||
- [ ] Is it a bug fix
|
||||
- [ ] Is a typo/documentation fix
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
4
.github/workflows/golang-test-darwin.yml
vendored
4
.github/workflows/golang-test-darwin.yml
vendored
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: macos-latest
|
||||
|
||||
14
.github/workflows/golang-test-linux.yml
vendored
14
.github/workflows/golang-test-linux.yml
vendored
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
@@ -31,13 +35,13 @@ jobs:
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: Test
|
||||
run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -66,13 +70,13 @@ jobs:
|
||||
run: go mod tidy
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: go test -c -o iface-testing.bin ./iface/...
|
||||
run: go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
||||
run: go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
@@ -89,4 +93,4 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
60
.github/workflows/golang-test-windows.yml
vendored
60
.github/workflows/golang-test-windows.yml
vendored
@@ -6,47 +6,45 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
pre:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- run: bash -x wireguard_nt.sh
|
||||
working-directory: client
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: syso
|
||||
path: client/*.syso
|
||||
retention-days: 1
|
||||
|
||||
test:
|
||||
needs: pre
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
id: go
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
|
||||
- uses: actions/cache@v2
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
path: |
|
||||
%LocalAppData%\go-build
|
||||
~\go\pkg\mod
|
||||
~\AppData\Local\go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
|
||||
- uses: actions/download-artifact@v2
|
||||
with:
|
||||
name: syso
|
||||
path: iface\
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Test
|
||||
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
- run: choco install -y sysinternals
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
3
.github/workflows/golangci-lint.yml
vendored
3
.github/workflows/golangci-lint.yml
vendored
@@ -1,5 +1,8 @@
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
|
||||
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Test installation Darwin
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
env:
|
||||
SKIP_UI_APP: true
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
install-all:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||
echo "Error: NetBird UI is not installed"
|
||||
exit 1
|
||||
fi
|
||||
38
.github/workflows/install-test-linux.yml
vendored
Normal file
38
.github/workflows/install-test-linux.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Test installation Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
check_bin_install: [true, false]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename apt package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: |
|
||||
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
26
.github/workflows/release.yml
vendored
26
.github/workflows/release.yml
vendored
@@ -9,8 +9,12 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.4"
|
||||
GORELEASER_VER: "v1.6.3"
|
||||
SIGN_PIPE_VER: "v0.0.6"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
@@ -21,10 +25,6 @@ jobs:
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Generate syso with DLL
|
||||
run: bash -x wireguard_nt.sh
|
||||
working-directory: client
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
@@ -57,7 +57,19 @@ jobs:
|
||||
with:
|
||||
username: netbirdio
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc amd64
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows rsrc arm64
|
||||
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows rsrc arm
|
||||
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||
- name: Generate windows rsrc 386
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
@@ -127,7 +139,7 @@ jobs:
|
||||
retention-days: 3
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-latest
|
||||
runs-on: macos-11
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
|
||||
15
.github/workflows/test-docker-compose-linux.yml
vendored
15
.github/workflows/test-docker-compose-linux.yml
vendored
@@ -6,6 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -59,6 +63,11 @@ jobs:
|
||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
||||
|
||||
run: |
|
||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -68,6 +77,12 @@ jobs:
|
||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep UseIDToken management.json | grep false
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
|
||||
@@ -25,14 +25,20 @@ builds:
|
||||
- goos: windows
|
||||
goarch: 386
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
- id: netbird-mgmt
|
||||
dir: management
|
||||
env: [CGO_ENABLED=0]
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-mgmt
|
||||
goos:
|
||||
- linux
|
||||
@@ -41,7 +47,7 @@ builds:
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-signal
|
||||
@@ -55,7 +61,7 @@ builds:
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
archives:
|
||||
|
||||
@@ -10,7 +10,7 @@ builds:
|
||||
goarch:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
@@ -24,7 +24,7 @@ builds:
|
||||
goarch:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -H windowsgui
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ builds:
|
||||
- hardfloat
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! DNS support.</strong>
|
||||
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird/releases">
|
||||
Learn more
|
||||
</a>
|
||||
@@ -56,10 +56,10 @@ NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactiv
|
||||
- \[x] Remote SSH access without managing SSH keys.
|
||||
- \[x] Network Routes.
|
||||
- \[x] Private DNS.
|
||||
- \[x] Network Activity Monitoring.
|
||||
|
||||
**Coming soon:**
|
||||
- \[ ] Mobile clients.
|
||||
- \[ ] Network Activity Monitoring.
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
|
||||
@@ -98,6 +98,7 @@ See a complete [architecture overview](https://netbird.io/docs/overview/architec
|
||||
|
||||
### Community projects
|
||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
FROM gcr.io/distroless/base:debug
|
||||
ENV WT_LOG_FILE=console
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
|
||||
SHELL ["/busybox/sh","-c"]
|
||||
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
|
||||
|
||||
129
client/android/client.go
Normal file
129
client/android/client.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
type ConnectionListener interface {
|
||||
peer.Listener
|
||||
}
|
||||
|
||||
// TunAdapter export internal TunAdapter for mobile
|
||||
type TunAdapter interface {
|
||||
iface.TunAdapter
|
||||
}
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
type IFaceDiscover interface {
|
||||
stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
func init() {
|
||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||
lvl, _ := log.ParseLevel("trace")
|
||||
log.SetLevel(lvl)
|
||||
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||
c.ctxCancelLock.Lock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
defer c.ctxCancel()
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.login(urlOpener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
func (c *Client) Stop() {
|
||||
c.ctxCancelLock.Lock()
|
||||
defer c.ctxCancelLock.Unlock()
|
||||
if c.ctxCancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.ctxCancel()
|
||||
}
|
||||
|
||||
// PeersList return with the list of the PeerInfos
|
||||
func (c *Client) PeersList() *PeerInfoArray {
|
||||
|
||||
fullStatus := c.recorder.GetFullStatus()
|
||||
|
||||
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||
for n, p := range fullStatus.Peers {
|
||||
pi := PeerInfo{
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
p.Direct,
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
// SetConnectionListener set the network connection listener
|
||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.SetConnectionListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove connection listener
|
||||
func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
229
client/android/login.go
Normal file
229
client/android/login.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// SSOListener is async listener for mobile framework
|
||||
type SSOListener interface {
|
||||
OnSuccess(bool)
|
||||
OnError(error)
|
||||
}
|
||||
|
||||
// ErrListener is async listener for mobile framework
|
||||
type ErrListener interface {
|
||||
OnSuccess()
|
||||
OnError(error)
|
||||
}
|
||||
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
}
|
||||
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
config *internal.Config
|
||||
cfgPath string
|
||||
}
|
||||
|
||||
// NewAuth instantiate Auth struct and validate the management URL
|
||||
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
inputCfg := internal.ConfigInput{
|
||||
ManagementURL: mgmURL,
|
||||
}
|
||||
|
||||
cfg, err := internal.CreateInMemoryConfig(inputCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Auth{
|
||||
ctx: context.Background(),
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
go func() {
|
||||
sso, err := a.saveConfigIfSSOSupported()
|
||||
if err != nil {
|
||||
listener.OnError(err)
|
||||
} else {
|
||||
listener.OnSuccess(sso)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
|
||||
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||
go func() {
|
||||
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return internal.WriteOutConfig(a.cfgPath, a.config)
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
go func() {
|
||||
err := a.login(urlOpener)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener) error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||
defer cancel()
|
||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
37
client/android/peer_notifier.go
Normal file
37
client/android/peer_notifier.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package android
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
Direct bool
|
||||
}
|
||||
|
||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||
type PeerInfoCollection interface {
|
||||
Add(s string) PeerInfoCollection
|
||||
Get(i int) string
|
||||
Size() int
|
||||
}
|
||||
|
||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||
type PeerInfoArray struct {
|
||||
items []PeerInfo
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array PeerInfoArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
78
client/android/preferences.go
Normal file
78
client/android/preferences.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences create new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL read url from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL read url from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// Commit write out the changes into config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
}
|
||||
120
client/android/preferences_test.go
Normal file
120
client/android/preferences_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestPreferences_DefaultValues(t *testing.T) {
|
||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||
p := NewPreferences(cfgFile)
|
||||
defaultVar, err := p.GetAdminURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read default value: %s", err)
|
||||
}
|
||||
|
||||
if defaultVar != internal.DefaultAdminURL {
|
||||
t.Errorf("invalid default admin url: %s", defaultVar)
|
||||
}
|
||||
|
||||
defaultVar, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read default management URL: %s", err)
|
||||
}
|
||||
|
||||
if defaultVar != internal.DefaultManagementURL {
|
||||
t.Errorf("invalid default management url: %s", defaultVar)
|
||||
}
|
||||
|
||||
var preSharedKey string
|
||||
preSharedKey, err = p.GetPreSharedKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read default preshared key: %s", err)
|
||||
}
|
||||
|
||||
if preSharedKey != "" {
|
||||
t.Errorf("invalid preshared key: %s", preSharedKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||
exampleString := "exampleString"
|
||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||
p := NewPreferences(cfgFile)
|
||||
|
||||
p.SetAdminURL(exampleString)
|
||||
resp, err := p.GetAdminURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read admin url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleString {
|
||||
t.Errorf("unexpected admin url: %s", resp)
|
||||
}
|
||||
|
||||
p.SetManagementURL(exampleString)
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleString {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
}
|
||||
|
||||
p.SetPreSharedKey(exampleString)
|
||||
resp, err = p.GetPreSharedKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read preshared key: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleString {
|
||||
t.Errorf("unexpected preshared key: %s", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreferences_Commit(t *testing.T) {
|
||||
exampleURL := "https://myurl.com:443"
|
||||
examplePresharedKey := "topsecret"
|
||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||
p := NewPreferences(cfgFile)
|
||||
|
||||
p.SetAdminURL(exampleURL)
|
||||
p.SetManagementURL(exampleURL)
|
||||
p.SetPreSharedKey(examplePresharedKey)
|
||||
|
||||
err := p.Commit()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to save changes: %s", err)
|
||||
}
|
||||
|
||||
p = NewPreferences(cfgFile)
|
||||
resp, err := p.GetAdminURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read admin url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleURL {
|
||||
t.Errorf("unexpected admin url: %s", resp)
|
||||
}
|
||||
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleURL {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
}
|
||||
|
||||
resp, err = p.GetPreSharedKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read preshared key: %s", err)
|
||||
}
|
||||
|
||||
if resp != examplePresharedKey {
|
||||
t.Errorf("unexpected preshared key: %s", resp)
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ var downCmd = &cobra.Command{
|
||||
Use: "down",
|
||||
Short: "down netbird connections",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -20,7 +21,7 @@ var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -38,7 +39,12 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.GetConfig(managementURL, adminURL, configPath, preSharedKey)
|
||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
PreSharedKey: &preSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
@@ -129,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.AccessToken
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
@@ -147,7 +153,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config)
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
@@ -157,7 +163,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
mgmtURL := managementURL
|
||||
if mgmtURL == "" {
|
||||
mgmtURL = internal.ManagementURLDefault().String()
|
||||
mgmtURL = internal.DefaultManagementURL
|
||||
}
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", mgmtURL)
|
||||
@@ -166,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
}
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
if err != nil {
|
||||
|
||||
@@ -24,6 +24,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
)
|
||||
|
||||
var (
|
||||
configPath string
|
||||
defaultConfigPathDir string
|
||||
@@ -41,6 +46,8 @@ var (
|
||||
adminURL string
|
||||
setupKey string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
customDNSAddress string
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
@@ -80,12 +87,12 @@ func init() {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVar(&managementURL, "management-url", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.ManagementURLDefault().String()))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "https://app.netbird.io", "Admin Panel URL [http|https]://[host]:[port]")
|
||||
rootCmd.PersistentFlags().StringVar(&configPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVar(&setupKey, "setup-key", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
@@ -96,6 +103,19 @@ func init() {
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||
`or --external-ip-map ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
|
||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||
)
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -115,8 +135,8 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
|
||||
}
|
||||
|
||||
// SetFlagsFromEnvVars reads and updates flag values from environment variables with prefix WT_
|
||||
func SetFlagsFromEnvVars() {
|
||||
flags := rootCmd.PersistentFlags()
|
||||
func SetFlagsFromEnvVars(cmd *cobra.Command) {
|
||||
flags := cmd.PersistentFlags()
|
||||
flags.VisitAll(func(f *pflag.Flag) {
|
||||
oldEnvVar := FlagNameToEnvVar(f.Name, "WT_")
|
||||
|
||||
|
||||
36
client/cmd/root_test.go
Normal file
36
client/cmd/root_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitCommands(t *testing.T) {
|
||||
helpFlag := "-h"
|
||||
commandArgs := [][]string{{"root", helpFlag}}
|
||||
for _, command := range rootCmd.Commands() {
|
||||
commandArgs = append(commandArgs, []string{command.Name(), command.Name(), helpFlag})
|
||||
for _, subcommand := range command.Commands() {
|
||||
commandArgs = append(commandArgs, []string{command.Name() + " " + subcommand.Name(), command.Name(), subcommand.Name(), helpFlag})
|
||||
}
|
||||
}
|
||||
|
||||
for _, args := range commandArgs {
|
||||
t.Run(fmt.Sprintf("Testing Command %s", args[0]), func(t *testing.T) {
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
t.Fatalf("got an panic error while running the command: %s -h. Error: %s", args[0], err)
|
||||
}
|
||||
}()
|
||||
|
||||
rootCmd.SetArgs(args[1:])
|
||||
rootCmd.SetOut(io.Discard)
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
t.Errorf("expected no error while running %s command, got %v", args[0], err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -54,7 +54,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, managementURL, adminURL, configPath, logFile)
|
||||
serverInstance := server.New(p.ctx, configPath, logFile)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
@@ -84,7 +84,7 @@ var runCmd = &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "runs Netbird as service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -118,7 +118,7 @@ var startCmd = &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "starts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -153,7 +153,7 @@ var stopCmd = &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "stops Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -186,7 +186,7 @@ var restartCmd = &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "restarts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ var installCmd = &cobra.Command{
|
||||
Use: "install",
|
||||
Short: "installs Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -86,7 +86,7 @@ var uninstallCmd = &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "uninstalls Netbird service from system",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
|
||||
@@ -4,15 +4,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -40,7 +42,8 @@ var sshCmd = &cobra.Command{
|
||||
},
|
||||
Short: "connect to a remote SSH server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -56,7 +59,9 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
config, err := internal.ReadConfig("", "", configPath, nil)
|
||||
config, err := internal.UpdateConfig(internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,25 +2,74 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Connected int `json:"connected" yaml:"connected"`
|
||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
Local string `json:"local" yaml:"local"`
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
@@ -29,67 +78,99 @@ var (
|
||||
var statusCmd = &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "status of the Netbird Service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := parseFilters()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
daemonStatus := fmt.Sprintf("Daemon status: %s\n", resp.GetStatus())
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
|
||||
cmd.Printf("%s\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||
daemonStatus,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
fullStatus := fromProtoFullStatus(pbFullStatus)
|
||||
|
||||
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
|
||||
|
||||
return nil
|
||||
},
|
||||
RunE: statusFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := parseFilters()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
resp, _ := getStatus(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||
resp.GetStatus(),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
if ipv4Flag {
|
||||
cmd.Print(parseInterfaceIP(resp.GetFullStatus().GetLocalPeerState().GetIP()))
|
||||
return nil
|
||||
}
|
||||
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
statusOutputString := ""
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Print(statusOutputString)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
@@ -109,195 +190,229 @@ func parseFilters() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
|
||||
var fullStatus nbStatus.FullStatus
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
fullStatus.ManagementState.URL = managementState.GetURL()
|
||||
fullStatus.ManagementState.Connected = managementState.GetConnected()
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
fullStatus.SignalState.URL = signalState.GetURL()
|
||||
fullStatus.SignalState.Connected = signalState.GetConnected()
|
||||
|
||||
localPeerState := pbFullStatus.GetLocalPeerState()
|
||||
fullStatus.LocalPeerState.IP = localPeerState.GetIP()
|
||||
fullStatus.LocalPeerState.PubKey = localPeerState.GetPubKey()
|
||||
fullStatus.LocalPeerState.KernelInterface = localPeerState.GetKernelInterface()
|
||||
fullStatus.LocalPeerState.FQDN = localPeerState.GetFqdn()
|
||||
|
||||
var peersState []nbStatus.PeerState
|
||||
|
||||
for _, pbPeerState := range pbFullStatus.GetPeers() {
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
peerState := nbStatus.PeerState{
|
||||
IP: pbPeerState.GetIP(),
|
||||
PubKey: pbPeerState.GetPubKey(),
|
||||
ConnStatus: pbPeerState.GetConnStatus(),
|
||||
ConnStatusUpdate: timeLocal,
|
||||
Relayed: pbPeerState.GetRelayed(),
|
||||
Direct: pbPeerState.GetDirect(),
|
||||
LocalIceCandidateType: pbPeerState.GetLocalIceCandidateType(),
|
||||
RemoteIceCandidateType: pbPeerState.GetRemoteIceCandidateType(),
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
}
|
||||
peersState = append(peersState, peerState)
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
}
|
||||
|
||||
fullStatus.Peers = peersState
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
}
|
||||
|
||||
return fullStatus
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
connType := ""
|
||||
peersConnected := 0
|
||||
for _, pbPeerState := range peers {
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected = peersConnected + 1
|
||||
|
||||
interfaceIP := fullStatus.LocalPeerState.IP
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
peerState := peerStateDetailOutput{
|
||||
IP: pbPeerState.GetIP(),
|
||||
PubKey: pbPeerState.GetPubKey(),
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal.UTC(),
|
||||
ConnType: connType,
|
||||
Direct: pbPeerState.GetDirect(),
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
}
|
||||
|
||||
sortPeersByIP(peersStateDetail)
|
||||
|
||||
peersOverview := peersStateOutput{
|
||||
Total: len(peersStateDetail),
|
||||
Connected: peersConnected,
|
||||
Details: peersStateDetail,
|
||||
}
|
||||
return peersOverview
|
||||
}
|
||||
|
||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
||||
if len(peersStateDetail) > 0 {
|
||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
||||
return iAddr.Compare(jAddr) == -1
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
}
|
||||
|
||||
if ipv4Flag {
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
var (
|
||||
managementStatusURL = ""
|
||||
signalStatusURL = ""
|
||||
managementConnString = "Disconnected"
|
||||
signalConnString = "Disconnected"
|
||||
interfaceTypeString = "Userspace"
|
||||
)
|
||||
|
||||
if printDetail {
|
||||
managementStatusURL = fmt.Sprintf(" to %s", fullStatus.ManagementState.URL)
|
||||
signalStatusURL = fmt.Sprintf(" to %s", fullStatus.SignalState.URL)
|
||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
if fullStatus.ManagementState.Connected {
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
|
||||
managementConnString := "Disconnected"
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
}
|
||||
|
||||
if fullStatus.SignalState.Connected {
|
||||
signalConnString := "Disconnected"
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
}
|
||||
|
||||
if fullStatus.LocalPeerState.KernelInterface {
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if fullStatus.LocalPeerState.IP == "" {
|
||||
} else if overview.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
parsedPeersString, peersConnected := parsePeers(fullStatus.Peers, printDetail)
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", peersConnected, len(fullStatus.Peers))
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"%s"+ // daemon status
|
||||
"Management: %s%s\n"+
|
||||
"Signal: %s%s\n"+
|
||||
"Domain: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
daemonVersion,
|
||||
system.NetbirdVersion(),
|
||||
daemonStatus,
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
managementStatusURL,
|
||||
signalConnString,
|
||||
signalStatusURL,
|
||||
fullStatus.LocalPeerState.FQDN,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
peersCountString,
|
||||
)
|
||||
|
||||
if printDetail {
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
"%s\n"+
|
||||
"%s",
|
||||
parsedPeersString,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func parsePeers(peers []nbStatus.PeerState, printDetail bool) (string, int) {
|
||||
var (
|
||||
peersString = ""
|
||||
peersConnected = 0
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers)
|
||||
summary := parseGeneralSummary(overview, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
"%s\n"+
|
||||
"%s",
|
||||
parsedPeersString,
|
||||
summary,
|
||||
)
|
||||
|
||||
if len(peers) > 0 {
|
||||
sort.SliceStable(peers, func(i, j int) bool {
|
||||
iAddr, _ := netip.ParseAddr(peers[i].IP)
|
||||
jAddr, _ := netip.ParseAddr(peers[j].IP)
|
||||
return iAddr.Compare(jAddr) == -1
|
||||
})
|
||||
}
|
||||
|
||||
connectedStatusString := peer.StatusConnected.String()
|
||||
|
||||
for _, peerState := range peers {
|
||||
peerConnectionStatus := false
|
||||
if peerState.ConnStatus == connectedStatusString {
|
||||
peersConnected = peersConnected + 1
|
||||
peerConnectionStatus = true
|
||||
}
|
||||
|
||||
if printDetail {
|
||||
|
||||
if skipDetailByFilters(peerState, peerConnectionStatus) {
|
||||
continue
|
||||
}
|
||||
|
||||
localICE := "-"
|
||||
remoteICE := "-"
|
||||
connType := "-"
|
||||
|
||||
if peerConnectionStatus {
|
||||
localICE = peerState.LocalIceCandidateType
|
||||
remoteICE = peerState.RemoteIceCandidateType
|
||||
connType = "P2P"
|
||||
if peerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
" Public key: %s\n"+
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.ConnStatus,
|
||||
connType,
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
peerState.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
|
||||
peersString = peersString + peerString
|
||||
}
|
||||
}
|
||||
return peersString, peersConnected
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState nbStatus.PeerState, isConnected bool) bool {
|
||||
func parsePeers(peers peersStateOutput) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
|
||||
for _, peerState := range peers.Details {
|
||||
|
||||
localICE := "-"
|
||||
if peerState.IceCandidateType.Local != "" {
|
||||
localICE = peerState.IceCandidateType.Local
|
||||
}
|
||||
|
||||
remoteICE := "-"
|
||||
if peerState.IceCandidateType.Remote != "" {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
" Public key: %s\n"+
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
|
||||
peersString = peersString + peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
|
||||
|
||||
301
client/cmd/status_test.go
Normal file
301
client/cmd/status_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var resp = &proto.StatusResponse{
|
||||
Status: "Connected",
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
}
|
||||
|
||||
var overview = statusOutputOverview{
|
||||
Peers: peersStateOutput{
|
||||
Total: 2,
|
||||
Connected: 2,
|
||||
Details: []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
FQDN: "peer-1.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
Direct: true,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
FQDN: "peer-2.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
Direct: false,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: "0.14.1",
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := convertToStatusOutputOverview(resp)
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
|
||||
func TestSortingOfPeers(t *testing.T) {
|
||||
peers := []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.104",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.105",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.103",
|
||||
},
|
||||
}
|
||||
|
||||
sortPeersByIP(peers)
|
||||
|
||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
json, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSON := "{\"" +
|
||||
"peers\":" +
|
||||
"{" +
|
||||
"\"total\":2," +
|
||||
"\"connected\":2," +
|
||||
"\"details\":" +
|
||||
"[" +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-1.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.101\"," +
|
||||
"\"publicKey\":\"Pubkey1\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2001-01-01T01:01:01Z\"," +
|
||||
"\"connectionType\":\"P2P\"," +
|
||||
"\"direct\":true," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"\"," +
|
||||
"\"remote\":\"\"" +
|
||||
"}" +
|
||||
"}," +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-2.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.102\"," +
|
||||
"\"publicKey\":\"Pubkey2\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2002-02-02T02:02:02Z\"," +
|
||||
"\"connectionType\":\"Relayed\"," +
|
||||
"\"direct\":false," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"relay\"," +
|
||||
"\"remote\":\"prflx\"" +
|
||||
"}" +
|
||||
"}" +
|
||||
"]" +
|
||||
"}," +
|
||||
"\"cliVersion\":\"development\"," +
|
||||
"\"daemonVersion\":\"0.14.1\"," +
|
||||
"\"management\":" +
|
||||
"{" +
|
||||
"\"url\":\"my-awesome-management.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"signal\":" +
|
||||
"{\"" +
|
||||
"url\":\"my-awesome-signal.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"netbirdIp\":\"192.168.178.100/16\"," +
|
||||
"\"publicKey\":\"Some-Pub-Key\"," +
|
||||
"\"usesKernelInterface\":true," +
|
||||
"\"fqdn\":\"some-localhost.awesome-domain.com\"" +
|
||||
"}"
|
||||
// @formatter:on
|
||||
|
||||
assert.Equal(t, expectedJSON, json)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML := "peers:\n" +
|
||||
" total: 2\n" +
|
||||
" connected: 2\n" +
|
||||
" details:\n" +
|
||||
" - fqdn: peer-1.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.101\n" +
|
||||
" publicKey: Pubkey1\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2001-01-01T01:01:01Z\n" +
|
||||
" connectionType: P2P\n" +
|
||||
" direct: true\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: \"\"\n" +
|
||||
" remote: \"\"\n" +
|
||||
" - fqdn: peer-2.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.102\n" +
|
||||
" publicKey: Pubkey2\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2002-02-02T02:02:02Z\n" +
|
||||
" connectionType: Relayed\n" +
|
||||
" direct: false\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: relay\n" +
|
||||
" remote: prflx\n" +
|
||||
"cliVersion: development\n" +
|
||||
"daemonVersion: 0.14.1\n" +
|
||||
"management:\n" +
|
||||
" url: my-awesome-management.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"signal:\n" +
|
||||
" url: my-awesome-signal.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"netbirdIp: 192.168.178.100/16\n" +
|
||||
"publicKey: Some-Pub-Key\n" +
|
||||
"usesKernelInterface: true\n" +
|
||||
"fqdn: some-localhost.awesome-domain.com\n"
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := "Peers detail:\n" +
|
||||
" peer-1.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.101\n" +
|
||||
" Public key: Pubkey1\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: P2P\n" +
|
||||
" Direct: true\n" +
|
||||
" ICE candidate (Local/Remote): -/-\n" +
|
||||
" Last connection update: 2001-01-01 01:01:01\n" +
|
||||
"\n" +
|
||||
" peer-2.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.102\n" +
|
||||
" Public key: Pubkey2\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: Relayed\n" +
|
||||
" Direct: false\n" +
|
||||
" ICE candidate (Local/Remote): relay/prflx\n" +
|
||||
" Last connection update: 2002-02-02 02:02:02\n" +
|
||||
"\n" +
|
||||
"Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected to my-awesome-management.com:443\n" +
|
||||
"Signal: Connected to my-awesome-signal.com:443\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false)
|
||||
|
||||
expectedString := "Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected\n" +
|
||||
"Signal: Connected\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
func TestParsingOfIP(t *testing.T) {
|
||||
InterfaceIP := "192.168.178.123/16"
|
||||
|
||||
parsedIP := parseInterfaceIP(InterfaceIP)
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -68,7 +69,12 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,7 +102,8 @@ func startClientDaemon(
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx, managementURL, adminURL, configPath, "")
|
||||
server := client.New(ctx,
|
||||
configPath, "")
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
332
client/cmd/up.go
332
client/cmd/up.go
@@ -3,126 +3,266 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "install, login and start Netbird client",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars()
|
||||
const (
|
||||
invalidInputType int = iota
|
||||
ipInputType
|
||||
interfaceInputType
|
||||
)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
var (
|
||||
foregroundMode bool
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "install, login and start Netbird client",
|
||||
RunE: upFunc,
|
||||
}
|
||||
)
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
err = validateNATExternalIPs(natExternalIPs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if foregroundMode {
|
||||
return runInForegroundMode(ctx, cmd)
|
||||
}
|
||||
return runInDaemonMode(ctx, cmd)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
PreSharedKey: &preSharedKey,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
log.Warnf("failed closing dameon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
config, err := internal.GetConfig(managementURL, adminURL, configPath, preSharedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
var loginErr error
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, nbStatus.NewRecorder())
|
||||
}
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing dameon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
cmd.Println("Already connected")
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
}
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
openURL(cmd, loginResp.VerificationURIComplete)
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
},
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNATExternalIPs(list []string) error {
|
||||
for _, element := range list {
|
||||
if element == "" {
|
||||
return fmt.Errorf("empty string is not a valid input for %s", externalIPMapFlag)
|
||||
}
|
||||
|
||||
subElements := strings.Split(element, "/")
|
||||
if len(subElements) > 2 {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
if len(subElements) == 1 && !isValidIP(subElements[0]) {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
last := 0
|
||||
for _, singleElement := range subElements {
|
||||
inputType, err := validateElement(singleElement)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be an IP string or a network name", singleElement, externalIPMapFlag)
|
||||
}
|
||||
if last == interfaceInputType && inputType == interfaceInputType {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should not contain two interface names", element, externalIPMapFlag)
|
||||
}
|
||||
last = inputType
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateElement(element string) (int, error) {
|
||||
if isValidIP(element) {
|
||||
return ipInputType, nil
|
||||
}
|
||||
validIface, err := isValidInterface(element)
|
||||
if err != nil {
|
||||
return invalidInputType, fmt.Errorf("unable to validate the network interface name, error: %s", err)
|
||||
}
|
||||
|
||||
if validIface {
|
||||
return interfaceInputType, nil
|
||||
}
|
||||
|
||||
return interfaceInputType, fmt.Errorf("invalid IP or network interface name not found")
|
||||
}
|
||||
|
||||
func isValidIP(ip string) bool {
|
||||
return net.ParseIP(ip) != nil
|
||||
}
|
||||
|
||||
func isValidInterface(name string) (bool, error) {
|
||||
netInterfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, iface := range netInterfaces {
|
||||
if iface.Name == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
var parsed []byte
|
||||
if modified {
|
||||
if !isValidAddrPort(customDNSAddress) {
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
}
|
||||
if customDNSAddress == "" && logFile != "console" {
|
||||
parsed = []byte("empty")
|
||||
} else {
|
||||
parsed = []byte(customDNSAddress)
|
||||
}
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func isValidAddrPort(input string) bool {
|
||||
if input == "" {
|
||||
return true
|
||||
}
|
||||
_, err := netip.ParseAddrPort(input)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -11,7 +12,7 @@ var (
|
||||
Short: "prints Netbird version",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
cmd.Println(system.NetbirdVersion())
|
||||
cmd.Println(version.NetbirdVersion())
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
57
client/firewall/firewall.go
Normal file
57
client/firewall/firewall.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// GetRuleID returns the rule id
|
||||
GetRuleID() string
|
||||
}
|
||||
|
||||
// Direction is the direction of the traffic
|
||||
type Direction int
|
||||
|
||||
const (
|
||||
// DirectionSrc is the direction of the traffic from the source
|
||||
DirectionSrc Direction = iota
|
||||
// DirectionDst is the direction of the traffic from the destination
|
||||
DirectionDst
|
||||
)
|
||||
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AddFiltering rule to the firewall
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
port *Port,
|
||||
direction Direction,
|
||||
action Action,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
160
client/firewall/iptables/manager_linux.go
Normal file
160
client/firewall/iptables/manager_linux.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChainFilterName is the name of the chain that is used for filtering by the Netbird client
|
||||
ChainFilterName = "NETBIRD-ACL"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create() (*Manager, error) {
|
||||
m := &Manager{}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv4Client = ipv4Client
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv6Client = ipv6Client
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
port *fw.Port,
|
||||
direction fw.Direction,
|
||||
action fw.Action,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
client := m.client(ip)
|
||||
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||
}
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||
}
|
||||
}
|
||||
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||
return nil, fmt.Errorf("invalid port definition")
|
||||
}
|
||||
pv := strconv.Itoa(port.Values[0])
|
||||
if port.IsRange {
|
||||
pv += ":" + strconv.Itoa(port.Values[1])
|
||||
}
|
||||
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rule := &Rule{
|
||||
id: uuid.New().String(),
|
||||
specs: specs,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
client := m.ipv4Client
|
||||
if r.v6 {
|
||||
client = m.ipv6Client
|
||||
}
|
||||
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||
}
|
||||
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||
ok, err := client.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||
return fmt.Errorf("failed to clear chain: %w", err)
|
||||
}
|
||||
return client.DeleteChain(table, ChainFilterName)
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, chain string, ip net.IP, port string,
|
||||
direction fw.Direction, action fw.Action, comment string,
|
||||
) (specs []string) {
|
||||
if direction == fw.DirectionSrc {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
}
|
||||
|
||||
// client returns corresponding iptables client for the given ip
|
||||
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||
if ip.To4() != nil {
|
||||
return m.ipv4Client
|
||||
}
|
||||
return m.ipv6Client
|
||||
}
|
||||
|
||||
func (m *Manager) actionToStr(action fw.Action) string {
|
||||
if action == fw.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
105
client/firewall/iptables/manager_linux_test.go
Normal file
105
client/firewall/iptables/manager_linux_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager, err := Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Proto: fw.PortProtocolTCP,
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule1); err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule2); err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("failed to reset: %v", err)
|
||||
}
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
t.Errorf("failed to drop chain: %v", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
t.Errorf("chain '%v' still exists after Reset", ChainFilterName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) {
|
||||
exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...)
|
||||
if err != nil {
|
||||
t.Errorf("failed to check rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !exists && mustExists {
|
||||
t.Errorf("rule '%v' does not exist", rulespec)
|
||||
return
|
||||
}
|
||||
if exists && !mustExists {
|
||||
t.Errorf("rule '%v' exist", rulespec)
|
||||
return
|
||||
}
|
||||
}
|
||||
13
client/firewall/iptables/rule.go
Normal file
13
client/firewall/iptables/rule.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
id string
|
||||
specs []string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
24
client/firewall/port.go
Normal file
24
client/firewall/port.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package firewall
|
||||
|
||||
// PortProtocol is the protocol of the port
|
||||
type PortProtocol string
|
||||
|
||||
const (
|
||||
// PortProtocolTCP is the TCP protocol
|
||||
PortProtocolTCP PortProtocol = "tcp"
|
||||
|
||||
// PortProtocolUDP is the UDP protocol
|
||||
PortProtocolUDP PortProtocol = "udp"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
type Port struct {
|
||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||
IsRange bool
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
|
||||
// Proto is the protocol of the port
|
||||
Proto PortProtocol
|
||||
}
|
||||
@@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
||||
Sleep 3000
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
|
||||
@@ -1,33 +1,42 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var managementURLDefault *url.URL
|
||||
const (
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
||||
ManagementLegacyPort = 33073
|
||||
// DefaultManagementURL points to the NetBird's cloud management endpoint
|
||||
DefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||
// DefaultAdminURL points to NetBird's cloud management console
|
||||
DefaultAdminURL = "https://app.netbird.io:443"
|
||||
)
|
||||
|
||||
func ManagementURLDefault() *url.URL {
|
||||
return managementURLDefault
|
||||
}
|
||||
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||
|
||||
func init() {
|
||||
managementURL, err := ParseURL("Management URL", "https://api.wiretrustee.com:443")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
managementURLDefault = managementURL
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
PreSharedKey *string
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -42,7 +51,7 @@ type Config struct {
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
SSHKey string
|
||||
|
||||
// ExternalIP mappings, if different than the host interface IP
|
||||
//
|
||||
@@ -60,10 +69,68 @@ type Config struct {
|
||||
// "12.34.56.78/10.1.2.3" => interface IP 10.1.2.3 will be mapped to external IP of 12.34.56.78
|
||||
|
||||
NATExternalIPs []string
|
||||
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
|
||||
CustomDNSAddress string
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
if configFileIsExists(configPath) {
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
if !configFileIsExists(input.ConfigPath) {
|
||||
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
||||
}
|
||||
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !configFileIsExists(input.ConfigPath) {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (*Config, error) {
|
||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
wgKey := generateKey()
|
||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
if err != nil {
|
||||
@@ -76,74 +143,59 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
|
||||
WgPort: iface.DefaultWgPort,
|
||||
IFaceBlackList: []string{},
|
||||
DisableIPv6Discovery: false,
|
||||
NATExternalIPs: input.NATExternalIPs,
|
||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||
}
|
||||
if managementURL != "" {
|
||||
URL, err := ParseURL("Management URL", managementURL)
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.ManagementURL = defaultManagementURL
|
||||
if input.ManagementURL != "" {
|
||||
URL, err := parseURL("Management URL", input.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ManagementURL = URL
|
||||
} else {
|
||||
config.ManagementURL = managementURLDefault
|
||||
}
|
||||
|
||||
if preSharedKey != "" {
|
||||
config.PreSharedKey = preSharedKey
|
||||
if input.PreSharedKey != nil {
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
}
|
||||
|
||||
if adminURL != "" {
|
||||
newURL, err := ParseURL("Admin Panel URL", adminURL)
|
||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.AdminURL = defaultAdminURL
|
||||
if input.AdminURL != "" {
|
||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.AdminURL = newURL
|
||||
}
|
||||
|
||||
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-"}
|
||||
|
||||
err = util.WriteJson(configPath, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ParseURL parses and validates management URL
|
||||
func ParseURL(serviceName, managementURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(managementURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing management URL %s: [%s]", managementURL, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parsedMgmtURL.Scheme != "https" && parsedMgmtURL.Scheme != "http" {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid %s URL provided %s. Supported format [http|https]://[host]:[port]",
|
||||
serviceName, managementURL)
|
||||
}
|
||||
|
||||
return parsedMgmtURL, err
|
||||
}
|
||||
|
||||
// ReadConfig reads existing config. In case provided managementURL is not empty overrides the read property
|
||||
func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string) (*Config, error) {
|
||||
func update(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
||||
}
|
||||
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refresh := false
|
||||
|
||||
if managementURL != "" && config.ManagementURL.String() != managementURL {
|
||||
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
|
||||
log.Infof("new Management URL provided, updated to %s (old value %s)",
|
||||
managementURL, config.ManagementURL)
|
||||
newURL, err := ParseURL("Management URL", managementURL)
|
||||
input.ManagementURL, config.ManagementURL)
|
||||
newURL, err := parseURL("Management URL", input.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -151,10 +203,10 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if adminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != adminURL) {
|
||||
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
|
||||
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
|
||||
adminURL, config.AdminURL)
|
||||
newURL, err := ParseURL("Admin Panel URL", adminURL)
|
||||
input.AdminURL, config.AdminURL)
|
||||
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,12 +214,13 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if preSharedKey != nil && config.PreSharedKey != *preSharedKey {
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||
*preSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *preSharedKey
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
if err != nil {
|
||||
@@ -181,10 +234,19 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
refresh = true
|
||||
}
|
||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
||||
config.NATExternalIPs = input.NATExternalIPs
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.CustomDNSAddress != nil {
|
||||
config.CustomDNSAddress = string(input.CustomDNSAddress)
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if refresh {
|
||||
// since we have new management URL, we need to update config file
|
||||
if err := util.WriteJson(configPath, config); err != nil {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -192,19 +254,32 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig reads existing config or generates a new one
|
||||
func GetConfig(managementURL, adminURL, configPath, preSharedKey string) (*Config, error) {
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
log.Infof("generating new config %s", configPath)
|
||||
return createNewConfig(managementURL, adminURL, configPath, preSharedKey)
|
||||
} else {
|
||||
// don't overwrite pre-shared key if we receive asterisks from UI
|
||||
pk := &preSharedKey
|
||||
if preSharedKey == "**********" {
|
||||
pk = nil
|
||||
}
|
||||
return ReadConfig(managementURL, adminURL, configPath, pk)
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing %s URL %s: [%s]", serviceName, serviceURL, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parsedMgmtURL.Scheme != "https" && parsedMgmtURL.Scheme != "http" {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid %s URL provided %s. Supported format [http|https]://[host]:[port]",
|
||||
serviceName, serviceURL)
|
||||
}
|
||||
|
||||
if parsedMgmtURL.Port() == "" {
|
||||
switch parsedMgmtURL.Scheme {
|
||||
case "https":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
||||
case "http":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
||||
default:
|
||||
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
||||
}
|
||||
}
|
||||
|
||||
return parsedMgmtURL, err
|
||||
}
|
||||
|
||||
// generateKey generates a new Wireguard private key
|
||||
@@ -216,107 +291,15 @@ func generateKey() string {
|
||||
return key.String()
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig ProviderConfig
|
||||
// don't overwrite pre-shared key if we receive asterisks from UI
|
||||
func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
if preSharedKey != nil && *preSharedKey == "**********" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type ProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
}
|
||||
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", config.ManagementURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", config.ManagementURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
} else {
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
},
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isProviderConfigValid(config ProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
return nil
|
||||
func configFileIsExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
@@ -10,17 +10,34 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
// case 1: new default config has to be generated
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
managementURL := "https://test.management.url:33071"
|
||||
adminURL := "https://app.admin.url"
|
||||
adminURL := "https://app.admin.url:443"
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
preSharedKey := "preSharedKey"
|
||||
|
||||
// case 1: new config has to be generated
|
||||
config, err := GetConfig(managementURL, adminURL, path, preSharedKey)
|
||||
// case 2: new config has to be generated
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &preSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -32,8 +49,13 @@ func TestGetConfig(t *testing.T) {
|
||||
t.Errorf("config file was expected to be created under path %s", path)
|
||||
}
|
||||
|
||||
// case 2: existing config -> fetch it
|
||||
config, err = GetConfig(managementURL, adminURL, path, preSharedKey)
|
||||
// case 3: existing config -> fetch it
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &preSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -41,9 +63,14 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 3: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = GetConfig(newManagementURL, adminURL, path, preSharedKey)
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &preSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -58,3 +85,40 @@ func TestGetConfig(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
||||
}
|
||||
|
||||
func TestHiddenPreSharedKey(t *testing.T) {
|
||||
hidden := "**********"
|
||||
samplePreSharedKey := "mysecretpresharedkey"
|
||||
tests := []struct {
|
||||
name string
|
||||
preSharedKey *string
|
||||
want string
|
||||
}{
|
||||
{"nil", nil, ""},
|
||||
{"hidden", &hidden, ""},
|
||||
{"filled", &samplePreSharedKey, samplePreSharedKey},
|
||||
}
|
||||
|
||||
// generate default cfg
|
||||
cfgFile := filepath.Join(t.TempDir(), "config.json")
|
||||
_, _ = UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: cfgFile,
|
||||
})
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: cfgFile,
|
||||
PreSharedKey: tt.preSharedKey,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get cfg: %s", err)
|
||||
}
|
||||
|
||||
if cfg.PreSharedKey != tt.want {
|
||||
t.Fatalf("invalid preshared key: '%s', expected: '%s' ", cfg.PreSharedKey, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,25 +6,24 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Status) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -60,9 +59,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
return err
|
||||
}
|
||||
|
||||
managementURL := config.ManagementURL.String()
|
||||
statusRecorder.MarkManagementDisconnected(managementURL)
|
||||
|
||||
defer statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
select {
|
||||
@@ -75,7 +72,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
|
||||
engineCtx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
statusRecorder.MarkManagementDisconnected(managementURL)
|
||||
statusRecorder.MarkManagementDisconnected()
|
||||
statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
@@ -85,6 +82,9 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
||||
mgmClient.SetConnStateListener(mgmNotifier)
|
||||
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
@@ -103,12 +103,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
}
|
||||
return wrapErr(err)
|
||||
}
|
||||
statusRecorder.MarkManagementConnected(managementURL)
|
||||
statusRecorder.MarkManagementConnected()
|
||||
|
||||
localPeerState := nbStatus.LocalPeerState{
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
|
||||
@@ -119,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||
)
|
||||
|
||||
statusRecorder.MarkSignalDisconnected(signalURL)
|
||||
defer statusRecorder.MarkSignalDisconnected(signalURL)
|
||||
statusRecorder.UpdateSignalAddress(signalURL)
|
||||
|
||||
statusRecorder.MarkSignalDisconnected()
|
||||
defer statusRecorder.MarkSignalDisconnected()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
@@ -135,7 +137,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
}
|
||||
}()
|
||||
|
||||
statusRecorder.MarkSignalConnected(signalURL)
|
||||
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
|
||||
signalClient.SetConnStateListener(signalNotifier)
|
||||
|
||||
statusRecorder.MarkSignalConnected()
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
@@ -145,7 +150,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder)
|
||||
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
@@ -155,7 +166,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
@@ -187,7 +201,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
@@ -197,6 +210,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
WgPort: config.WgPort,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -245,17 +259,17 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
||||
const ManagementLegacyPort = 33073
|
||||
|
||||
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
if config.ManagementURL.Hostname() != ManagementURLDefault().Hostname() {
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
@@ -272,7 +286,7 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
||||
|
||||
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
||||
|
||||
newURL, err := ParseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -306,7 +320,10 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := ReadConfig(newURL.String(), "", configPath, nil)
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
@@ -318,3 +335,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
var sri interface{} = statusRecorder
|
||||
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
||||
return mgmNotifier
|
||||
}
|
||||
|
||||
func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier {
|
||||
var sri interface{} = statusRecorder
|
||||
notifier, _ := sri.(signal.ConnStateNotifier)
|
||||
return notifier
|
||||
}
|
||||
|
||||
134
client/internal/device_auth.go
Normal file
134
client/internal/device_auth.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig ProviderConfig
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type ProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isProviderConfigValid(config ProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,8 +3,9 @@ package dns
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -14,6 +15,7 @@ const (
|
||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||
)
|
||||
|
||||
const (
|
||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||
fileMaxLineCharsLimit = 256
|
||||
@@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
|
||||
@@ -2,8 +2,9 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
@@ -19,6 +20,7 @@ type hostDNSConfig struct {
|
||||
}
|
||||
|
||||
type domainConfig struct {
|
||||
disabled bool
|
||||
domain string
|
||||
matchOnly bool
|
||||
}
|
||||
@@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
|
||||
serverPort: port,
|
||||
}
|
||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||
if len(nsConfig.NameServers) == 0 {
|
||||
continue
|
||||
}
|
||||
if nsConfig.Primary {
|
||||
config.routeAll = true
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, dConf.domain)
|
||||
continue
|
||||
|
||||
@@ -2,10 +2,11 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if !dConf.matchOnly {
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
|
||||
@@ -4,14 +4,15 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -75,12 +76,12 @@ func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager,
|
||||
}
|
||||
defer closeConn()
|
||||
var s string
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s)
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName())
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
@@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
||||
matchDomains []string
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||
continue
|
||||
|
||||
@@ -2,10 +2,11 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
@@ -16,7 +17,7 @@ type resolvconf struct {
|
||||
|
||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.GetName(),
|
||||
ifaceName: wgInterface.Name(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -1,26 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
import nbdns "github.com/netbirdio/netbird/dns"
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
@@ -28,338 +8,3 @@ type Server interface {
|
||||
Stop()
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
}
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
}
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
|
||||
hostmanager, err := newHostManager(wgInterface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defaultServer.hostManager = hostmanager
|
||||
return defaultServer, err
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err = s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.GetAddress().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.stop()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
}
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
||||
}
|
||||
handler := &upstreamResolver{
|
||||
parentCTX: s.ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: defaultUpstreamTimeout,
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
32
client/internal/dns/server_android.go
Normal file
32
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// DefaultServer dummy dns server
|
||||
type DefaultServer struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||
return &DefaultServer{}, nil
|
||||
}
|
||||
|
||||
// Start dummy implementation
|
||||
func (s DefaultServer) Start() {
|
||||
|
||||
}
|
||||
|
||||
// Stop dummy implementation
|
||||
func (s DefaultServer) Stop() {
|
||||
|
||||
}
|
||||
|
||||
// UpdateDNSServer dummy implementation
|
||||
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
465
client/internal/dns/server_nonandroid.go
Normal file
465
client/internal/dns/server_nonandroid.go
Normal file
@@ -0,0 +1,465 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
upstreamCtxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||
}
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
runtimePort: defaultPort,
|
||||
customAddress: addrPort,
|
||||
}
|
||||
|
||||
hostmanager, err := newHostManager(wgInterface)
|
||||
if err != nil {
|
||||
stop()
|
||||
return nil, err
|
||||
}
|
||||
defaultServer.hostManager = hostmanager
|
||||
return defaultServer, err
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
} else {
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.ctxCancel()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
}
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
// clean up the previous upstream resolver
|
||||
if s.upstreamCtxCancel != nil {
|
||||
s.upstreamCtxCancel()
|
||||
}
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
var ctx context.Context
|
||||
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||
|
||||
handler := newUpstreamResolver(ctx)
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
// when upstream fails to resolve domain several times over all it servers
|
||||
// it will calls this hook to exclude self from the configuration and
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
removeIndex[domain] = -1
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
}
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -3,15 +3,18 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
@@ -25,7 +28,6 @@ var zoneRecords = []nbdns.SimpleRecord{
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
@@ -200,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU)
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -214,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -265,91 +271,170 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
dnsServer := getDefaultServerWithNoHostManager("127.0.0.1")
|
||||
|
||||
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
|
||||
// todo review why this test is not working only on github actions workflows
|
||||
t.Skip("skipping test in Windows CI workflows.")
|
||||
}
|
||||
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
dnsServer.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if !dnsServer.listenerIsRunning {
|
||||
t.Fatal("dns server listener is not running")
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
// retry test before exit, for slower systems
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
testCases := []struct {
|
||||
name string
|
||||
addrPort string
|
||||
}{
|
||||
{
|
||||
name: "Should Pass With Port Discovery",
|
||||
},
|
||||
{
|
||||
name: "Should Pass With Custom Port",
|
||||
addrPort: "127.0.0.1:3535",
|
||||
},
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to the server, error: %v", err)
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
||||
|
||||
t.Log(ips)
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
dnsServer.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if !dnsServer.listenerIsRunning {
|
||||
t.Fatal("dns server listener is not running")
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
||||
}
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
dnsServer.Stop()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
|
||||
defer cancel()
|
||||
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
||||
if err == nil {
|
||||
t.Fatalf("we should encounter an error when querying a stopped server")
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
// retry test before exit, for slower systems
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to the server, error: %v", err)
|
||||
}
|
||||
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
||||
}
|
||||
|
||||
dnsServer.Stop()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
|
||||
defer cancel()
|
||||
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
||||
if err == nil {
|
||||
t.Fatalf("we should encounter an error when querying a stopped server")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(ip string) *DefaultServer {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
dnsMux: dns.DefaultServeMux,
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
hostManager: hostManager,
|
||||
currentConfig: hostDNSConfig{
|
||||
domains: []domainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
deactivate()
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||
mux := dns.NewServeMux()
|
||||
listenIP := defaultIP
|
||||
if ip != "" {
|
||||
listenIP = ip
|
||||
|
||||
var parsedAddrPort *netip.AddrPort
|
||||
if addrPort != "" {
|
||||
parsed, err := netip.ParseAddrPort(addrPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedAddrPort = &parsed
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", ip, defaultPort),
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(context.TODO())
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
return &DefaultServer{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
ctxCancel: cancel,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
runtimePort: defaultPort,
|
||||
runtimeIP: listenIP,
|
||||
customAddress: parsedAddrPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,16 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,7 +51,7 @@ type systemdDbusLinkDomainsInput struct {
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface.GetName())
|
||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
domainsInput []systemdDbusLinkDomainsInput
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.domain),
|
||||
MatchOnly: dConf.matchOnly,
|
||||
|
||||
@@ -3,44 +3,73 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultUpstreamTimeout = 15 * time.Second
|
||||
const (
|
||||
failsTillDeact = int32(3)
|
||||
reactivatePeriod = time.Minute
|
||||
upstreamTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
parentCTX context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
upstreamTimeout time.Duration
|
||||
ctx context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
deactivate func()
|
||||
reactivate func()
|
||||
}
|
||||
|
||||
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
||||
return &upstreamResolver{
|
||||
ctx: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
defer u.checkUpstreamFails()
|
||||
|
||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
|
||||
select {
|
||||
case <-u.parentCTX.Done():
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Warn("got an error while connecting to upstream")
|
||||
continue
|
||||
}
|
||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
||||
u.failsCount.Add(1)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Error("got an error while querying the upstream")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
||||
}
|
||||
// count the fails only if they happen sequentially
|
||||
u.failsCount.Store(0)
|
||||
return
|
||||
}
|
||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
||||
u.failsCount.Add(1)
|
||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
||||
}
|
||||
|
||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||
//
|
||||
// If fails count is greater that failsTillDeact, upstream resolving
|
||||
// will be disabled for reactivatePeriod, after that time period fails counter
|
||||
// will be reset and upstream will be reactivated.
|
||||
func (u *upstreamResolver) checkUpstreamFails() {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||
u.deactivate()
|
||||
u.disabled = true
|
||||
go u.waitUntilReactivation()
|
||||
}
|
||||
}
|
||||
|
||||
// waitUntilReactivation reset fails counter and activates upstream resolving
|
||||
func (u *upstreamResolver) waitUntilReactivation() {
|
||||
timer := time.NewTimer(u.reactivatePeriod)
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
log.Info("upstream resolving is reactivated")
|
||||
u.failsCount.Store(0)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
}
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: defaultUpstreamTimeout,
|
||||
timeout: upstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
@@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: defaultUpstreamTimeout,
|
||||
timeout: upstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
//{
|
||||
@@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver := &upstreamResolver{
|
||||
parentCTX: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamServers: testCase.InputServers,
|
||||
upstreamTimeout: testCase.timeout,
|
||||
}
|
||||
resolver := newUpstreamResolver(ctx)
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
} else {
|
||||
@@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
resolver := newUpstreamResolver(context.TODO())
|
||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||
resolver.failsTillDeact = 0
|
||||
resolver.reactivatePeriod = time.Microsecond * 100
|
||||
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||
}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func() {
|
||||
failed = true
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
t.Errorf("resolver should be disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.failsCount.Load() != 0 {
|
||||
t.Errorf("fails count after reactivation should be 0")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,24 +12,23 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -70,6 +69,8 @@ type EngineConfig struct {
|
||||
SSHKey []byte
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
CustomDNSAddress string
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -84,7 +85,9 @@ type Engine struct {
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
|
||||
// STUNs is a list of STUN servers used by ICE
|
||||
STUNs []*ice.URL
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
@@ -107,7 +110,7 @@ type Engine struct {
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *nbstatus.Status
|
||||
statusRecorder *peer.Status
|
||||
|
||||
routeManager routemanager.Manager
|
||||
|
||||
@@ -124,16 +127,17 @@ type Peer struct {
|
||||
func NewEngine(
|
||||
ctx context.Context, cancel context.CancelFunc,
|
||||
signalClient signal.Client, mgmClient mgm.Client,
|
||||
config *EngineConfig, statusRecorder *nbstatus.Status,
|
||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: signalClient,
|
||||
mgmClient: mgmClient,
|
||||
peerConns: map[string]*peer.Conn{},
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*ice.URL{},
|
||||
TURNs: []*ice.URL{},
|
||||
networkSerial: 0,
|
||||
@@ -155,114 +159,89 @@ func (e *Engine) Stop() error {
|
||||
// Removing peers happens in the conn.CLose() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface.Interface != nil {
|
||||
err = e.wgInterface.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMux != nil {
|
||||
if err := e.udpMux.Close(); err != nil {
|
||||
log.Debugf("close udp mux: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxSrflx != nil {
|
||||
if err := e.udpMuxSrflx.Close(); err != nil {
|
||||
log.Debugf("close server reflexive udp mux: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxConn != nil {
|
||||
if err := e.udpMuxConn.Close(); err != nil {
|
||||
log.Debugf("close udp mux connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxConnSrflx != nil {
|
||||
if err := e.udpMuxConnSrflx.Close(); err != nil {
|
||||
log.Debugf("close server reflexive udp mux connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
e.close()
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
wgIfaceName := e.config.WgIfaceName
|
||||
wgIFaceName := e.config.WgIfaceName
|
||||
wgAddr := e.config.WgAddr
|
||||
myPrivateKey := e.config.WgPrivateKey
|
||||
var err error
|
||||
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
transportNet, err := e.newStdNet()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
networkName := "udp"
|
||||
if e.config.DisableIPv6Discovery {
|
||||
networkName = "udp4"
|
||||
}
|
||||
|
||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
|
||||
|
||||
err = e.wgInterface.Create()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
||||
if err != nil {
|
||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
|
||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
if e.wgInterface.IsUserspaceBind() {
|
||||
iceBind := e.wgInterface.GetBind()
|
||||
udpMux, err := iceBind.GetICEMux()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.udpMux = udpMux.UDPMuxDefault
|
||||
e.udpMuxSrflx = udpMux
|
||||
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||
} else {
|
||||
networkName := "udp"
|
||||
if e.config.DisableIPv6Discovery {
|
||||
networkName = "udp4"
|
||||
}
|
||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
udpMuxParams := ice.UDPMuxParams{
|
||||
UDPConn: e.udpMuxConn,
|
||||
Net: transportNet,
|
||||
}
|
||||
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||
}
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||
|
||||
if e.dnsServer == nil {
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface)
|
||||
// todo fix custom address
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -378,42 +357,6 @@ func (e *Engine) removePeer(peerKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerConnectionStatus returns a connection Status or nil if peer connection wasn't found
|
||||
func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
||||
conn, exists := e.peerConns[peerKey]
|
||||
if exists && conn != nil {
|
||||
return conn.Status()
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
func (e *Engine) GetPeers() []string {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
peers := []string{}
|
||||
for s := range e.peerConns {
|
||||
peers = append(peers, s)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func (e *Engine) GetConnectedPeers() []string {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
peers := []string{}
|
||||
for s, conn := range e.peerConns {
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
peers = append(peers, s)
|
||||
}
|
||||
}
|
||||
|
||||
return peers
|
||||
}
|
||||
|
||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||
err := s.Send(&sProto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
@@ -430,6 +373,10 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendSignal(message *sProto.Message, s signal.Client) error {
|
||||
return s.Send(message)
|
||||
}
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||
var t sProto.Body_Type
|
||||
@@ -446,6 +393,10 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// indicates message support in gRPC
|
||||
msg.Body.FeaturesSupported = []uint32{signal.DirectCheck}
|
||||
|
||||
err = s.Send(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -498,7 +449,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
//nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
||||
fmt.Sprintf("%s:%d", e.wgInterface.Address.IP.String(), nbssh.DefaultSSHPort))
|
||||
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -531,8 +482,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface.Address.String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address.String()
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address().String()
|
||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||
err := e.wgInterface.UpdateAddr(conf.Address)
|
||||
if err != nil {
|
||||
@@ -549,10 +500,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
e.statusRecorder.UpdateLocalPeerState(nbstatus.LocalPeerState{
|
||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||
IP: e.config.WgAddr,
|
||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: conf.GetFqdn(),
|
||||
})
|
||||
|
||||
@@ -634,6 +585,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
|
||||
// cleanup request, most likely our peer has been deleted
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
err := e.removeAllPeers()
|
||||
@@ -681,6 +634,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
@@ -750,6 +704,21 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
return dnsUpdate
|
||||
}
|
||||
|
||||
func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
replacement := make([]peer.State, len(offlinePeers))
|
||||
for i, offlinePeer := range offlinePeers {
|
||||
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
||||
replacement[i] = peer.State{
|
||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||
PubKey: offlinePeer.GetWgPubKey(),
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusDisconnected,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
}
|
||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
@@ -860,9 +829,10 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
ProxyConfig: proxyConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -887,6 +857,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
peerConn.SetSignalCandidate(signalCandidate)
|
||||
peerConn.SetSignalOffer(signalOffer)
|
||||
peerConn.SetSignalAnswer(signalAnswer)
|
||||
peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
|
||||
return sendSignal(message, e.signal)
|
||||
})
|
||||
|
||||
return peerConn, nil
|
||||
}
|
||||
@@ -910,6 +883,9 @@ func (e *Engine) receiveSignalEvents() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||
|
||||
conn.OnRemoteOffer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -923,6 +899,9 @@ func (e *Engine) receiveSignalEvents() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||
|
||||
conn.OnRemoteAnswer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -938,6 +917,19 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteCandidate(candidate)
|
||||
case sProto.Body_MODE:
|
||||
protoMode := msg.GetBody().GetMode()
|
||||
if protoMode == nil {
|
||||
return fmt.Errorf("received an empty mode message")
|
||||
}
|
||||
|
||||
err := conn.OnModeMessage(peer.ModeMessage{
|
||||
Direct: protoMode.GetDirect(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed processing a mode message -> %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -964,6 +956,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
var external, internal string
|
||||
var externalIP, internalIP net.IP
|
||||
var err error
|
||||
|
||||
split := strings.Split(mapping, "/")
|
||||
if len(split) > 2 {
|
||||
log.Warnf("ignoring invalid external mapping '%s', too many delimiters", mapping)
|
||||
@@ -988,7 +981,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
external = split[0]
|
||||
externalIP = net.ParseIP(external)
|
||||
if externalIP == nil {
|
||||
log.Warnf("invalid external IP, ignoring external IP mapping '%s'", mapping)
|
||||
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
|
||||
break
|
||||
}
|
||||
if externalIP != nil {
|
||||
@@ -1007,6 +1000,49 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
return mappedIPs
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMux != nil {
|
||||
if err := e.udpMux.Close(); err != nil {
|
||||
log.Debugf("close udp mux: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxConn != nil {
|
||||
if err := e.udpMuxConn.Close(); err != nil {
|
||||
log.Debugf("close udp mux connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.udpMuxConnSrflx != nil {
|
||||
if err := e.udpMuxConnSrflx.Close(); err != nil {
|
||||
log.Debugf("close server reflexive udp mux connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
11
client/internal/engine_stdnet.go
Normal file
11
client/internal/engine_stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||
}
|
||||
7
client/internal/engine_stdnet_android.go
Normal file
7
client/internal/engine_stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package internal
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
}
|
||||
@@ -3,14 +3,8 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -21,18 +15,29 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -203,12 +208,24 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
conn, err := net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -387,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -437,7 +454,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
default:
|
||||
}
|
||||
|
||||
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -545,8 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -710,8 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, nbstatus.NewRecorder())
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
@@ -777,15 +802,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sport := 10010
|
||||
sigServer, err := startSignal(sport)
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mport := 33081
|
||||
mgmtServer, err := startManagement(mport, dir)
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -803,7 +826,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mport, sport)
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
@@ -846,7 +869,7 @@ loop:
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected = totalConnected + len(engine.GetConnectedPeers())
|
||||
totalConnected = totalConnected + getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
@@ -857,7 +880,7 @@ loop:
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name, n)
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
@@ -869,16 +892,84 @@ loop:
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mport int, sport int) (*Engine, error) {
|
||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
ifaceList, err := net.Interfaces()
|
||||
if err != nil {
|
||||
t.Fatalf("could get the interface list, got error: %s", err)
|
||||
}
|
||||
|
||||
var testingIP string
|
||||
var testingInterface string
|
||||
|
||||
for _, iface := range ifaceList {
|
||||
addrList, err := iface.Addrs()
|
||||
if err != nil {
|
||||
t.Fatalf("could get the addr list, got error: %s", err)
|
||||
}
|
||||
for _, addr := range addrList {
|
||||
prefix := netip.MustParsePrefix(addr.String())
|
||||
if prefix.Addr().Is4() && !prefix.Addr().IsLoopback() {
|
||||
testingIP = prefix.Addr().String()
|
||||
testingInterface = iface.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputMapList []string
|
||||
inputBlacklistInterface []string
|
||||
expectedOutput []string
|
||||
}{
|
||||
{
|
||||
name: "Parse Valid List Should Be OK",
|
||||
inputBlacklistInterface: defaultInterfaceBlacklist,
|
||||
inputMapList: []string{"1.1.1.1", "8.8.8.8/" + testingInterface},
|
||||
expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP},
|
||||
},
|
||||
{
|
||||
name: "Only Interface name Should Return Nil",
|
||||
inputBlacklistInterface: defaultInterfaceBlacklist,
|
||||
inputMapList: []string{testingInterface},
|
||||
expectedOutput: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid IP Return Nil",
|
||||
inputBlacklistInterface: defaultInterfaceBlacklist,
|
||||
inputMapList: []string{"1.1.1.1000"},
|
||||
expectedOutput: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid Mapping Element Should return Nil",
|
||||
inputBlacklistInterface: defaultInterfaceBlacklist,
|
||||
inputMapList: []string{"1.1.1.1/10.10.10.1/eth0"},
|
||||
expectedOutput: nil,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
IFaceBlackList: testCase.inputBlacklistInterface,
|
||||
NATExternalIPs: testCase.inputMapList,
|
||||
},
|
||||
}
|
||||
parsedList := engine.parseNATExternalIPMappings()
|
||||
require.ElementsMatchf(t, testCase.expectedOutput, parsedList, "elements of parsed list should match expected list")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, fmt.Sprintf("localhost:%d", mport), key, false)
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, fmt.Sprintf("localhost:%d", sport), key, false)
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -909,13 +1000,13 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, nbstatus.NewRecorder()), nil
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
}
|
||||
|
||||
func startSignal(port int) (*grpc.Server, error) {
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
@@ -928,10 +1019,10 @@ func startSignal(port int) (*grpc.Server, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
@@ -943,9 +1034,9 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := server.NewFileStore(config.Datadir)
|
||||
@@ -953,14 +1044,19 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", nil
|
||||
}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
@@ -969,5 +1065,25 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, conn := range e.peerConns {
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerConns)
|
||||
}
|
||||
|
||||
@@ -2,37 +2,26 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
||||
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err)
|
||||
return err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
@@ -42,47 +31,84 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed logging-in peer on Management Service : %v", err)
|
||||
return err
|
||||
}
|
||||
log.Infof("peer has successfully logged-in to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close the Management service client: %v", err)
|
||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
||||
func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey)
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
log.Debugf("peer registration required")
|
||||
return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||
return serverKey, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
@@ -105,3 +131,31 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
13
client/internal/mobile_dependency.go
Normal file
13
client/internal/mobile_dependency.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
Routes []string
|
||||
}
|
||||
29
client/internal/mobile_dependency_android.go
Normal file
29
client/internal/mobile_dependency_android.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||
md := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: ifaceDiscover,
|
||||
}
|
||||
err := md.readMap(mgmClient)
|
||||
return md, err
|
||||
}
|
||||
|
||||
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
|
||||
routes, err := mgmClient.GetRoutes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Routes = make([]string, len(routes))
|
||||
for i, r := range routes {
|
||||
d.Routes[i] = r.GetNetwork()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
13
client/internal/mobile_dependency_nonandroid.go
Normal file
13
client/internal/mobile_dependency_nonandroid.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
)
|
||||
|
||||
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
|
||||
return MobileDependency{}, nil
|
||||
}
|
||||
@@ -35,15 +35,6 @@ type DeviceAuthInfo struct {
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
const (
|
||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
@@ -52,14 +43,7 @@ const (
|
||||
|
||||
// Hosted client
|
||||
type Hosted struct {
|
||||
// Hosted API Audience for validation
|
||||
Audience string
|
||||
// Hosted Native application client id
|
||||
ClientID string
|
||||
// TokenEndpoint to request access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint to request device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
providerConfig ProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
@@ -68,6 +52,7 @@ type Hosted struct {
|
||||
type RequestDeviceCodePayload struct {
|
||||
Audience string `json:"audience"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
@@ -90,8 +75,26 @@ type Claims struct {
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UseIDToken bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||
func (t TokenInfo) GetTokenToUse() string {
|
||||
if t.UseIDToken {
|
||||
return t.IDToken
|
||||
}
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
|
||||
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -101,25 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string,
|
||||
}
|
||||
|
||||
return &Hosted{
|
||||
Audience: audience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: tokenEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
DeviceAuthEndpoint: deviceAuthEndpoint,
|
||||
providerConfig: config,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||
return h.ClientID
|
||||
return h.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("audience", h.Audience)
|
||||
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("audience", h.providerConfig.Audience)
|
||||
form.Add("scope", h.providerConfig.Scope)
|
||||
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
@@ -152,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
|
||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
@@ -220,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
UseIDToken: h.providerConfig.UseIDToken,
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,14 +3,15 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -59,9 +60,11 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
|
||||
expectedAudience := "ok"
|
||||
expectedClientID := "bla"
|
||||
expectedScope := "openid"
|
||||
form := url.Values{}
|
||||
form.Add("audience", expectedAudience)
|
||||
form.Add("client_id", expectedClientID)
|
||||
form.Add("scope", expectedScope)
|
||||
expectPayload := form.Encode()
|
||||
|
||||
testCase1 := test{
|
||||
@@ -111,11 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
providerConfig: ProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||
@@ -272,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
providerConfig: ProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -2,18 +2,21 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/pion/ice/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
@@ -42,6 +45,9 @@ type ConnConfig struct {
|
||||
LocalWgPort int
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
|
||||
UserspaceBind bool
|
||||
}
|
||||
|
||||
// OfferAnswer represents a session establishment offer or answer
|
||||
@@ -69,8 +75,9 @@ type Conn struct {
|
||||
// signalCandidate is a handler function to signal remote peer about local connection candidate
|
||||
signalCandidate func(candidate ice.Candidate) error
|
||||
// signalOffer is a handler function to signal remote peer our connection offer (credentials)
|
||||
signalOffer func(OfferAnswer) error
|
||||
signalAnswer func(OfferAnswer) error
|
||||
signalOffer func(OfferAnswer) error
|
||||
signalAnswer func(OfferAnswer) error
|
||||
sendSignalMessage func(message *sProto.Message) error
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
@@ -83,9 +90,25 @@ type Conn struct {
|
||||
agent *ice.Agent
|
||||
status ConnStatus
|
||||
|
||||
statusRecorder *nbStatus.Status
|
||||
statusRecorder *Status
|
||||
|
||||
proxy proxy.Proxy
|
||||
proxy proxy.Proxy
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
type meta struct {
|
||||
protoSupport signal.FeaturesSupport
|
||||
}
|
||||
|
||||
// ModeMessage represents a connection mode chosen by the peer
|
||||
type ModeMessage struct {
|
||||
// Direct indicates that it decided to use a direct connection
|
||||
Direct bool
|
||||
}
|
||||
|
||||
// GetConf returns the connection config
|
||||
@@ -100,7 +123,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||
return &Conn{
|
||||
config: config,
|
||||
mu: sync.Mutex{},
|
||||
@@ -109,53 +132,34 @@ func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error)
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
|
||||
// to avoid building tunnel over them
|
||||
func interfaceFilter(blackList []string) func(string) bool {
|
||||
|
||||
return func(iFace string) bool {
|
||||
for _, s := range blackList {
|
||||
if strings.HasPrefix(iFace, s) {
|
||||
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
||||
return false
|
||||
}
|
||||
}
|
||||
// look for unlisted WireGuard interfaces
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
log.Debugf("trying to create a wgctrl client failed with: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := wg.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = wg.Device(iFace)
|
||||
return err != nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) reCreateAgent() error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
failedTimeout := 6 * time.Second
|
||||
|
||||
var err error
|
||||
transportNet, err := conn.newStdNet()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: conn.config.StunTurn,
|
||||
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
|
||||
FailedTimeout: &failedTimeout,
|
||||
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
|
||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||
UDPMux: conn.config.UDPMux,
|
||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||
Net: transportNet,
|
||||
}
|
||||
|
||||
if conn.config.DisableIPv6Discovery {
|
||||
@@ -192,11 +196,11 @@ func (conn *Conn) reCreateAgent() error {
|
||||
func (conn *Conn) Open() error {
|
||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||
|
||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.ConnStatus = conn.status.String()
|
||||
peerState.ConnStatus = conn.status
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -252,9 +256,9 @@ func (conn *Conn) Open() error {
|
||||
defer conn.notifyDisconnected()
|
||||
conn.mu.Unlock()
|
||||
|
||||
peerState = nbStatus.PeerState{PubKey: conn.config.Key}
|
||||
peerState = State{PubKey: conn.config.Key}
|
||||
|
||||
peerState.ConnStatus = conn.status.String()
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -291,7 +295,7 @@ func (conn *Conn) Open() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if conn.proxy.Type() == proxy.TypeNoProxy {
|
||||
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
||||
// direct Wireguard connection
|
||||
@@ -312,39 +316,82 @@ func (conn *Conn) Open() error {
|
||||
}
|
||||
|
||||
// useProxy determines whether a direct connection (without a go proxy) is possible
|
||||
// There are 3 cases: one of the peers has a public IP or both peers are in the same private network
|
||||
//
|
||||
// There are 3 cases:
|
||||
//
|
||||
// * When neither candidate is from hard nat and one of the peers has a public IP
|
||||
//
|
||||
// * both peers are in the same private network
|
||||
//
|
||||
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
|
||||
//
|
||||
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
||||
func shouldUseProxy(pair *ice.CandidatePair) bool {
|
||||
remoteIP := net.ParseIP(pair.Remote.Address())
|
||||
myIp := net.ParseIP(pair.Local.Address())
|
||||
remoteIsPublic := IsPublicIP(remoteIP)
|
||||
myIsPublic := IsPublicIP(myIp)
|
||||
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
|
||||
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
return true
|
||||
}
|
||||
|
||||
//one of the hosts has a public IP
|
||||
if remoteIsPublic && pair.Remote.Type() == ice.CandidateTypeHost {
|
||||
return false
|
||||
}
|
||||
if myIsPublic && pair.Local.Type() == ice.CandidateTypeHost {
|
||||
if !isRelayCandidate(pair.Local) && userspaceBind {
|
||||
log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed")
|
||||
return false
|
||||
}
|
||||
|
||||
if pair.Local.Type() == ice.CandidateTypeHost && pair.Remote.Type() == ice.CandidateTypeHost {
|
||||
if !remoteIsPublic && !myIsPublic {
|
||||
//both hosts are in the same private network
|
||||
return false
|
||||
}
|
||||
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
|
||||
log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP")
|
||||
return false
|
||||
}
|
||||
|
||||
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
|
||||
log.Debugf("shouldn't use proxy because the remote peer is not behind a hard NAT and the local one has a public IP")
|
||||
return false
|
||||
}
|
||||
|
||||
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
|
||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network")
|
||||
return false
|
||||
}
|
||||
|
||||
if (isPeerReflexiveCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) ||
|
||||
isHostCandidateWithPrivateIP(pair.Local) && isPeerReflexiveCandidateWithPrivateIP(pair.Remote)) && isSameNetworkPrefix(pair) {
|
||||
log.Debugf("shouldn't use proxy because peers are in the same private /16 network and one peer is peer reflexive")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsPublicIP indicates whether IP is public or not.
|
||||
func IsPublicIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
||||
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
|
||||
|
||||
localIP := net.ParseIP(pair.Local.Address())
|
||||
remoteIP := net.ParseIP(pair.Remote.Address())
|
||||
if localIP == nil || remoteIP == nil {
|
||||
return false
|
||||
}
|
||||
// only consider /16 networks
|
||||
mask := net.IPMask{255, 255, 0, 0}
|
||||
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
func isHardNATCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
|
||||
}
|
||||
|
||||
func isHostCandidateWithPublicIP(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeHost && isPublicIP(candidate.Address())
|
||||
}
|
||||
|
||||
func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
|
||||
}
|
||||
|
||||
func isPeerReflexiveCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypePeerReflexive && !isPublicIP(candidate.Address())
|
||||
}
|
||||
|
||||
func isPublicIP(address string) bool {
|
||||
ip := net.ParseIP(address)
|
||||
if ip == nil || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -361,16 +408,8 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
||||
useProxy := shouldUseProxy(pair)
|
||||
var p proxy.Proxy
|
||||
if useProxy {
|
||||
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
|
||||
peerState.Direct = false
|
||||
} else {
|
||||
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
peerState.Direct = true
|
||||
}
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
p := conn.getProxyWithMessageExchange(pair, remoteWgPort)
|
||||
conn.proxy = p
|
||||
err = p.Start(remoteConn)
|
||||
if err != nil {
|
||||
@@ -379,13 +418,14 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
|
||||
conn.status = StatusConnected
|
||||
|
||||
peerState.ConnStatus = conn.status.String()
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy || p.Type() == proxy.TypeNoProxy
|
||||
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -395,6 +435,65 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||
useProxy := shouldUseProxy(pair, conn.config.UserspaceBind)
|
||||
localDirectMode := !useProxy
|
||||
remoteDirectMode := localDirectMode
|
||||
|
||||
if conn.meta.protoSupport.DirectCheck {
|
||||
go conn.sendLocalDirectMode(localDirectMode)
|
||||
// will block until message received or timeout
|
||||
remoteDirectMode = conn.receiveRemoteDirectMode()
|
||||
}
|
||||
|
||||
if conn.config.UserspaceBind && localDirectMode {
|
||||
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
if localDirectMode && remoteDirectMode {
|
||||
return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
||||
}
|
||||
|
||||
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
||||
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||
}
|
||||
|
||||
func (conn *Conn) sendLocalDirectMode(localMode bool) {
|
||||
// todo what happens when we couldn't deliver this message?
|
||||
// we could retry, etc but there is no guarantee
|
||||
|
||||
err := conn.sendSignalMessage(&sProto.Message{
|
||||
Key: conn.config.LocalKey,
|
||||
RemoteKey: conn.config.Key,
|
||||
Body: &sProto.Body{
|
||||
Type: sProto.Body_MODE,
|
||||
Mode: &sProto.Mode{
|
||||
Direct: &localMode,
|
||||
},
|
||||
NetBirdVersion: version.NetbirdVersion(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to send local proxy mode to remote peer %s, error: %s", conn.config.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) receiveRemoteDirectMode() bool {
|
||||
timeout := time.Second
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case receivedMSG := <-conn.remoteModeCh:
|
||||
return receivedMSG.Direct
|
||||
case <-timer.C:
|
||||
// we didn't receive a message from remote so we assume that it supports the direct mode to keep the old behaviour
|
||||
log.Debugf("timeout after %s while waiting for remote direct mode message from remote peer %s",
|
||||
timeout, conn.config.Key)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup closes all open resources and sets status to StatusDisconnected
|
||||
func (conn *Conn) cleanup() error {
|
||||
log.Debugf("trying to cleanup %s", conn.config.Key)
|
||||
@@ -424,8 +523,8 @@ func (conn *Conn) cleanup() error {
|
||||
|
||||
conn.status = StatusDisconnected
|
||||
|
||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
||||
peerState.ConnStatus = conn.status.String()
|
||||
peerState := State{PubKey: conn.config.Key}
|
||||
peerState.ConnStatus = conn.status
|
||||
peerState.ConnStatusUpdate = time.Now()
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
@@ -455,6 +554,11 @@ func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error
|
||||
conn.signalCandidate = handler
|
||||
}
|
||||
|
||||
// SetSendSignalMessage sets a handler function to be triggered by Conn when there is new message to send via signal
|
||||
func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) error) {
|
||||
conn.sendSignalMessage = handler
|
||||
}
|
||||
|
||||
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
||||
// and then signals them to the remote peer
|
||||
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
||||
@@ -496,7 +600,7 @@ func (conn *Conn) sendAnswer() error {
|
||||
err = conn.signalAnswer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||
WgListenPort: conn.config.LocalWgPort,
|
||||
Version: system.NetbirdVersion(),
|
||||
Version: version.NetbirdVersion(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -517,7 +621,7 @@ func (conn *Conn) sendOffer() error {
|
||||
err = conn.signalOffer(OfferAnswer{
|
||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||
WgListenPort: conn.config.LocalWgPort,
|
||||
Version: system.NetbirdVersion(),
|
||||
Version: version.NetbirdVersion(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -609,3 +713,19 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
||||
func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
// OnModeMessage unmarshall the payload message and send it to the mode message channel
|
||||
func (conn *Conn) OnModeMessage(message ModeMessage) error {
|
||||
select {
|
||||
case conn.remoteModeCh <- message:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unable to process mode message: channel busy")
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||
protoSupport := signal.ParseFeaturesSupported(support)
|
||||
conn.meta.protoSupport = protoSupport
|
||||
}
|
||||
|
||||
29
client/internal/peer/conn_status.go
Normal file
29
client/internal/peer/conn_status.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package peer
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
const (
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
StatusConnected ConnStatus = iota
|
||||
// StatusConnecting indicate the peer is in connecting state
|
||||
StatusConnecting
|
||||
// StatusDisconnected indicate the peer is in disconnected state
|
||||
StatusDisconnected
|
||||
)
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
case StatusConnecting:
|
||||
return "Connecting"
|
||||
case StatusConnected:
|
||||
return "Connected"
|
||||
case StatusDisconnected:
|
||||
return "Disconnected"
|
||||
default:
|
||||
log.Errorf("unknown status: %d", s)
|
||||
return "INVALID_PEER_CONNECTION_STATUS"
|
||||
}
|
||||
}
|
||||
27
client/internal/peer/conn_status_test.go
Normal file
27
client/internal/peer/conn_status_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/magiconair/properties/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnStatus_String(t *testing.T) {
|
||||
|
||||
tables := []struct {
|
||||
name string
|
||||
status ConnStatus
|
||||
want string
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, "Connected"},
|
||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
got := table.status.String()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,14 +1,19 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/pion/ice/v2"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/pion/ice/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
sproto "github.com/netbirdio/netbird/signal/proto"
|
||||
)
|
||||
|
||||
var connConf = ConnConfig{
|
||||
@@ -25,7 +30,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||
"Tailscale", "tailscale"}
|
||||
|
||||
filter := interfaceFilter(ignore)
|
||||
filter := stdnet.InterfaceFilter(ignore)
|
||||
|
||||
for _, s := range ignore {
|
||||
assert.Equal(t, filter(s), false)
|
||||
@@ -34,7 +39,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
conn, err := NewConn(connConf, nil)
|
||||
conn, err := NewConn(connConf, nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -46,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -80,7 +85,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -113,7 +118,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -140,7 +145,7 @@ func TestConn_Status(t *testing.T) {
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -165,3 +170,310 @@ func TestConn_Close(t *testing.T) {
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type mockICECandidate struct {
|
||||
ice.Candidate
|
||||
AddressFunc func() string
|
||||
TypeFunc func() ice.CandidateType
|
||||
}
|
||||
|
||||
// Address mocks and overwrite ice.Candidate Address method
|
||||
func (m *mockICECandidate) Address() string {
|
||||
if m.AddressFunc != nil {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Type mocks and overwrite ice.Candidate Type method
|
||||
func (m *mockICECandidate) Type() ice.CandidateType {
|
||||
if m.TypeFunc != nil {
|
||||
return m.TypeFunc()
|
||||
}
|
||||
return ice.CandidateTypeUnspecified
|
||||
}
|
||||
|
||||
func TestConn_ShouldUseProxy(t *testing.T) {
|
||||
publicHostCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "8.8.8.8"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeHost
|
||||
},
|
||||
}
|
||||
privateHostCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "10.0.0.1"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeHost
|
||||
},
|
||||
}
|
||||
|
||||
srflxCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeServerReflexive
|
||||
},
|
||||
}
|
||||
|
||||
prflxCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
},
|
||||
}
|
||||
|
||||
relayCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeRelay
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
candatePair *ice.CandidatePair
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Use Proxy When Local Candidate Is Relay",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: relayCandidate,
|
||||
Remote: privateHostCandidate,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Use Proxy When Remote Candidate Is Relay",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: privateHostCandidate,
|
||||
Remote: relayCandidate,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Use Proxy When Local Candidate Is Peer Reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: prflxCandidate,
|
||||
Remote: privateHostCandidate,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Use Proxy When Remote Candidate Is Peer Reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: privateHostCandidate,
|
||||
Remote: prflxCandidate,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Local Candidate Is Public And Remote Is Private",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: privateHostCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Remote Candidate Is Public And Local Is Private",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: privateHostCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Local Candidate is Public And Remote Is Server Reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: srflxCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Remote Candidate is Public And Local Is Server Reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: srflxCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Both Candidates Are Public",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Both Candidates Are Private",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: privateHostCandidate,
|
||||
Remote: privateHostCandidate,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Don't Use Proxy When Both Candidates are in private network and one is peer reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.102.168"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeHost
|
||||
}},
|
||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.101.96"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Should Use Proxy When Both Candidates are in private network and both are peer reflexive",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.102.168"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
Remote: &mockICECandidate{AddressFunc: func() string {
|
||||
return "10.16.101.96"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypePeerReflexive
|
||||
}},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := shouldUseProxy(testCase.candatePair, false)
|
||||
if result != testCase.expected {
|
||||
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
publicHostCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "8.8.8.8"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeHost
|
||||
},
|
||||
}
|
||||
relayCandidate := &mockICECandidate{
|
||||
AddressFunc: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
TypeFunc: func() ice.CandidateType {
|
||||
return ice.CandidateTypeRelay
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
candatePair *ice.CandidatePair
|
||||
inputDirectModeSupport bool
|
||||
inputRemoteModeMessage bool
|
||||
expected proxy.Type
|
||||
}{
|
||||
{
|
||||
name: "Should Result In Using Wireguard Proxy When Local Eval Is Use Proxy",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: relayCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: true,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: relayCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
inputDirectModeSupport: false,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeWireGuard,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
inputDirectModeSupport: false,
|
||||
inputRemoteModeMessage: false,
|
||||
expected: proxy.TypeDirectNoProxy,
|
||||
},
|
||||
{
|
||||
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
|
||||
candatePair: &ice.CandidatePair{
|
||||
Local: publicHostCandidate,
|
||||
Remote: publicHostCandidate,
|
||||
},
|
||||
inputDirectModeSupport: true,
|
||||
inputRemoteModeMessage: true,
|
||||
expected: proxy.TypeDirectNoProxy,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
g := errgroup.Group{}
|
||||
conn, err := NewConn(connConf, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.meta.protoSupport.DirectCheck = testCase.inputDirectModeSupport
|
||||
conn.SetSendSignalMessage(func(message *sproto.Message) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return conn.OnModeMessage(ModeMessage{
|
||||
Direct: testCase.inputRemoteModeMessage,
|
||||
})
|
||||
})
|
||||
|
||||
resultProxy := conn.getProxyWithMessageExchange(testCase.candatePair, 1000)
|
||||
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resultProxy.Type() != testCase.expected {
|
||||
t.Errorf("result didn't match expected value: Expected: %s, Got: %s", testCase.expected, resultProxy.Type())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
11
client/internal/peer/listener.go
Normal file
11
client/internal/peer/listener.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package peer
|
||||
|
||||
// Listener is a callback type about the NetBird network connection state
|
||||
type Listener interface {
|
||||
OnConnected()
|
||||
OnDisconnected()
|
||||
OnConnecting()
|
||||
OnDisconnecting()
|
||||
OnAddressChanged(string, string)
|
||||
OnPeersListChanged(int)
|
||||
}
|
||||
142
client/internal/peer/notifier.go
Normal file
142
client/internal/peer/notifier.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
stateDisconnected = iota
|
||||
stateConnected
|
||||
stateConnecting
|
||||
stateDisconnecting
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
serverStateLock sync.Mutex
|
||||
listenersLock sync.Mutex
|
||||
listener Listener
|
||||
currentClientState bool
|
||||
lastNotification int
|
||||
}
|
||||
|
||||
func newNotifier() *notifier {
|
||||
return ¬ifier{}
|
||||
}
|
||||
|
||||
func (n *notifier) setListener(listener Listener) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
n.serverStateLock.Lock()
|
||||
n.notifyListener(listener, n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) removeListener() {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
n.listener = nil
|
||||
}
|
||||
|
||||
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
|
||||
calculatedState := n.calculateState(mgmState, signalState)
|
||||
|
||||
if !n.isServerStateChanged(calculatedState) {
|
||||
return
|
||||
}
|
||||
|
||||
n.lastNotification = calculatedState
|
||||
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStart() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = true
|
||||
n.lastNotification = stateConnected
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStop() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = stateDisconnected
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientTearDown() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = stateDisconnecting
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) isServerStateChanged(newState int) bool {
|
||||
return n.lastNotification != newState
|
||||
}
|
||||
|
||||
func (n *notifier) notify(state int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.notifyListener(n.listener, state)
|
||||
}
|
||||
|
||||
func (n *notifier) notifyListener(l Listener, state int) {
|
||||
go func() {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
case stateDisconnecting:
|
||||
l.OnDisconnecting()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||
if managementConn && signalConn {
|
||||
return stateConnected
|
||||
}
|
||||
|
||||
if !managementConn && !signalConn {
|
||||
return stateDisconnected
|
||||
}
|
||||
|
||||
if n.lastNotification == stateDisconnecting {
|
||||
return stateDisconnecting
|
||||
}
|
||||
|
||||
return stateConnecting
|
||||
}
|
||||
|
||||
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnPeersListChanged(numOfPeers)
|
||||
}
|
||||
|
||||
func (n *notifier) localAddressChanged(fqdn, address string) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnAddressChanged(fqdn, address)
|
||||
}
|
||||
97
client/internal/peer/notifier_test.go
Normal file
97
client/internal/peer/notifier_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mocListener struct {
|
||||
lastState int
|
||||
wg sync.WaitGroup
|
||||
peers int
|
||||
}
|
||||
|
||||
func (l *mocListener) OnConnected() {
|
||||
l.lastState = stateConnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnected() {
|
||||
l.lastState = stateDisconnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnConnecting() {
|
||||
l.lastState = stateConnecting
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnecting() {
|
||||
l.lastState = stateDisconnecting
|
||||
l.wg.Done()
|
||||
}
|
||||
|
||||
func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||
|
||||
}
|
||||
func (l *mocListener) OnPeersListChanged(size int) {
|
||||
l.peers = size
|
||||
}
|
||||
|
||||
func (l *mocListener) setWaiter() {
|
||||
l.wg.Add(1)
|
||||
}
|
||||
|
||||
func (l *mocListener) wait() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func Test_notifier_serverState(t *testing.T) {
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
expected int
|
||||
mgmState bool
|
||||
signalState bool
|
||||
}
|
||||
scenarios := []scenario{
|
||||
{"connected", stateConnected, true, true},
|
||||
{"mgm down", stateConnecting, false, true},
|
||||
{"signal down", stateConnecting, true, false},
|
||||
{"disconnected", stateDisconnected, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range scenarios {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := newNotifier()
|
||||
n.updateServerStates(tt.mgmState, tt.signalState)
|
||||
if n.lastNotification != tt.expected {
|
||||
t.Errorf("invalid serverstate: %d, expected: %d", n.lastNotification, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_notifier_SetListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
listener.wait()
|
||||
if listener.lastState != n.lastNotification {
|
||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_notifier_RemoveListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
n.removeListener()
|
||||
n.peerListChanged(1)
|
||||
|
||||
if listener.peers != 0 {
|
||||
t.Errorf("invalid state: %d", listener.peers)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +1,316 @@
|
||||
package peer
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnStatus int
|
||||
// State contains the latest state of a peer
|
||||
type State struct {
|
||||
IP string
|
||||
PubKey string
|
||||
FQDN string
|
||||
ConnStatus ConnStatus
|
||||
ConnStatusUpdate time.Time
|
||||
Relayed bool
|
||||
Direct bool
|
||||
LocalIceCandidateType string
|
||||
RemoteIceCandidateType string
|
||||
}
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
case StatusConnecting:
|
||||
return "Connecting"
|
||||
case StatusConnected:
|
||||
return "Connected"
|
||||
case StatusDisconnected:
|
||||
return "Disconnected"
|
||||
default:
|
||||
log.Errorf("unknown status: %d", s)
|
||||
return "INVALID_PEER_CONNECTION_STATUS"
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
type LocalPeerState struct {
|
||||
IP string
|
||||
PubKey string
|
||||
KernelInterface bool
|
||||
FQDN string
|
||||
}
|
||||
|
||||
// SignalState contains the latest state of a signal connection
|
||||
type SignalState struct {
|
||||
URL string
|
||||
Connected bool
|
||||
}
|
||||
|
||||
// ManagementState contains the latest state of a management connection
|
||||
type ManagementState struct {
|
||||
URL string
|
||||
Connected bool
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
type FullStatus struct {
|
||||
Peers []State
|
||||
ManagementState ManagementState
|
||||
SignalState SignalState
|
||||
LocalPeerState LocalPeerState
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal and management connections
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
managementState bool
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
StatusConnected ConnStatus = iota
|
||||
StatusConnecting
|
||||
StatusDisconnected
|
||||
)
|
||||
// ReplaceOfflinePeers replaces
|
||||
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.offlinePeers = make([]State, len(replacement))
|
||||
copy(d.offlinePeers, replacement)
|
||||
}
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
func (d *Status) AddPeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if ok {
|
||||
return errors.New("peer already exist")
|
||||
}
|
||||
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, errors.New("peer not found")
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if ok {
|
||||
delete(d.peers, peerPubKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
// UpdatePeerState updates peer status
|
||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[receivedState.PubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
if receivedState.IP != "" {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
peerState.ConnStatus = receivedState.ConnStatus
|
||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||
peerState.Direct = receivedState.Direct
|
||||
peerState.Relayed = receivedState.Relayed
|
||||
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
||||
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
||||
}
|
||||
|
||||
d.peers[receivedState.PubKey] = peerState
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerFQDN update peer's state fqdn only
|
||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.FQDN = fqdn
|
||||
d.peers[peerPubKey] = peerState
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
ch, found := d.changeNotify[peer]
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// UpdateLocalPeerState updates local peer status
|
||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = localPeerState
|
||||
d.notifyAddressChanged()
|
||||
}
|
||||
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = LocalPeerState{}
|
||||
d.notifyAddressChanged()
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = false
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = true
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.signalAddress = signalURL
|
||||
}
|
||||
|
||||
// UpdateManagementAddress update the address of the management server
|
||||
func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mgmAddress = mgmAddress
|
||||
}
|
||||
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = false
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = true
|
||||
}
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
fullStatus := FullStatus{
|
||||
ManagementState: ManagementState{
|
||||
d.mgmAddress,
|
||||
d.managementState,
|
||||
},
|
||||
SignalState: SignalState{
|
||||
d.signalAddress,
|
||||
d.signalState,
|
||||
},
|
||||
LocalPeerState: d.localPeer,
|
||||
}
|
||||
|
||||
for _, status := range d.peers {
|
||||
fullStatus.Peers = append(fullStatus.Peers, status)
|
||||
}
|
||||
|
||||
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.notifier.clientStart()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.notifier.clientTearDown()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
func (d *Status) SetConnectionListener(listener Listener) {
|
||||
d.notifier.setListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove the listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener() {
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(len(d.peers))
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,233 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/magiconair/properties/assert"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnStatus_String(t *testing.T) {
|
||||
func TestAddPeer(t *testing.T) {
|
||||
key := "abc"
|
||||
status := NewRecorder("https://mgm")
|
||||
err := status.AddPeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
tables := []struct {
|
||||
name string
|
||||
status ConnStatus
|
||||
want string
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, "Connected"},
|
||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||
_, exists := status.peers[key]
|
||||
assert.True(t, exists, "value was found")
|
||||
|
||||
err = status.AddPeer(key)
|
||||
|
||||
assert.Error(t, err, "should return error on duplicate")
|
||||
}
|
||||
|
||||
func TestGetPeer(t *testing.T) {
|
||||
key := "abc"
|
||||
status := NewRecorder("https://mgm")
|
||||
err := status.AddPeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
peerStatus, err := status.GetPeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error on getting peer")
|
||||
|
||||
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
||||
|
||||
_, err = status.GetPeer("non_existing_key")
|
||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||
}
|
||||
|
||||
func TestUpdatePeerState(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
got := table.status.String()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
status.peers[key] = peerState
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
state, exists := status.peers[key]
|
||||
assert.True(t, exists, "state should be found")
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
err := status.UpdatePeerFQDN(key, fqdn)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
state, exists := status.peers[key]
|
||||
assert.True(t, exists, "state should be found")
|
||||
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
|
||||
}
|
||||
|
||||
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "10.10.10.10"
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
ch := status.GetPeerStateChangeNotifier(key)
|
||||
assert.NotNil(t, ch, "channel shouldn't be nil")
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Errorf("channel wasn't closed after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
key := "abc"
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
err := status.RemovePeer(key)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
_, exists := status.peers[key]
|
||||
assert.False(t, exists, "state value shouldn't be found")
|
||||
|
||||
err = status.RemovePeer("not existing")
|
||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||
}
|
||||
|
||||
func TestUpdateLocalPeerState(t *testing.T) {
|
||||
localPeerState := LocalPeerState{
|
||||
IP: "10.10.10.10",
|
||||
PubKey: "abc",
|
||||
KernelInterface: false,
|
||||
}
|
||||
status := NewRecorder("https://mgm")
|
||||
|
||||
status.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
assert.Equal(t, localPeerState, status.localPeer, "local peer status should be equal")
|
||||
}
|
||||
|
||||
func TestCleanLocalPeerState(t *testing.T) {
|
||||
emptyLocalPeerState := LocalPeerState{}
|
||||
localPeerState := LocalPeerState{
|
||||
IP: "10.10.10.10",
|
||||
PubKey: "abc",
|
||||
KernelInterface: false,
|
||||
}
|
||||
status := NewRecorder("https://mgm")
|
||||
|
||||
status.localPeer = localPeerState
|
||||
|
||||
status.CleanLocalPeerState()
|
||||
|
||||
assert.Equal(t, emptyLocalPeerState, status.localPeer, "local peer status should be empty")
|
||||
}
|
||||
|
||||
func TestUpdateSignalState(t *testing.T) {
|
||||
url := "https://signal"
|
||||
var tests = []struct {
|
||||
name string
|
||||
connected bool
|
||||
want bool
|
||||
}{
|
||||
{"should mark as connected", true, true},
|
||||
{"should mark as disconnected", false, false},
|
||||
}
|
||||
|
||||
status := NewRecorder("https://mgm")
|
||||
status.UpdateSignalAddress(url)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if test.connected {
|
||||
status.MarkSignalConnected()
|
||||
} else {
|
||||
status.MarkSignalDisconnected()
|
||||
}
|
||||
assert.Equal(t, test.want, status.signalState, "signal status should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestUpdateManagementState(t *testing.T) {
|
||||
url := "https://management"
|
||||
var tests = []struct {
|
||||
name string
|
||||
connected bool
|
||||
want bool
|
||||
}{
|
||||
{"should mark as connected", true, true},
|
||||
{"should mark as disconnected", false, false},
|
||||
}
|
||||
|
||||
status := NewRecorder(url)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if test.connected {
|
||||
status.MarkManagementConnected()
|
||||
} else {
|
||||
status.MarkManagementDisconnected()
|
||||
}
|
||||
assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFullStatus(t *testing.T) {
|
||||
key1 := "abc"
|
||||
key2 := "def"
|
||||
signalAddr := "https://signal"
|
||||
managementState := ManagementState{
|
||||
URL: "https://mgm",
|
||||
Connected: true,
|
||||
}
|
||||
signalState := SignalState{
|
||||
URL: signalAddr,
|
||||
Connected: true,
|
||||
}
|
||||
peerState1 := State{
|
||||
PubKey: key1,
|
||||
}
|
||||
|
||||
peerState2 := State{
|
||||
PubKey: key2,
|
||||
}
|
||||
|
||||
status := NewRecorder("https://mgm")
|
||||
status.UpdateSignalAddress(signalAddr)
|
||||
|
||||
status.managementState = managementState.Connected
|
||||
status.signalState = signalState.Connected
|
||||
status.peers[key1] = peerState1
|
||||
status.peers[key2] = peerState2
|
||||
|
||||
fullStatus := status.GetFullStatus()
|
||||
|
||||
assert.Equal(t, managementState, fullStatus.ManagementState, "management status should be equal")
|
||||
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
||||
assert.ElementsMatch(t, []State{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
||||
}
|
||||
|
||||
11
client/internal/peer/stdnet.go
Normal file
11
client/internal/peer/stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !android
|
||||
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(conn.config.InterfaceBlackList)
|
||||
}
|
||||
7
client/internal/peer/stdnet_android.go
Normal file
7
client/internal/peer/stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package peer
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
|
||||
}
|
||||
57
client/internal/proxy/direct.go
Normal file
57
client/internal/proxy/direct.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
)
|
||||
|
||||
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
|
||||
// This is possible in either of these cases:
|
||||
// - peers are in the same local network
|
||||
// - one of the peers has a public static IP (host)
|
||||
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
|
||||
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
|
||||
type DirectNoProxy struct {
|
||||
config Config
|
||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||
// It is used instead of the hardcoded 51820 port.
|
||||
RemoteWgListenPort int
|
||||
}
|
||||
|
||||
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
|
||||
func NewDirectNoProxy(config Config, remoteWgPort int) *DirectNoProxy {
|
||||
return &DirectNoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *DirectNoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates WireGuard peer with the remote IP and default WireGuard port
|
||||
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using DirectNoProxy while connecting to peer %s", p.config.RemoteKey)
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr.Port = p.RemoteWgListenPort
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the type of this proxy
|
||||
func (p *DirectNoProxy) Type() Type {
|
||||
return TypeDirectNoProxy
|
||||
}
|
||||
@@ -5,24 +5,18 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// NoProxy is used when there is no need for a proxy between ICE and Wireguard.
|
||||
// This is possible in either of these cases:
|
||||
// - peers are in the same local network
|
||||
// - one of the peers has a public static IP (host)
|
||||
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
|
||||
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
|
||||
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
||||
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
||||
type NoProxy struct {
|
||||
config Config
|
||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||
// It is used instead of the hardcoded 51820 port.
|
||||
RemoteWgListenPort int
|
||||
}
|
||||
|
||||
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
|
||||
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
|
||||
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
||||
// NewNoProxy creates a new NoProxy with a provided config
|
||||
func NewNoProxy(config Config) *NoProxy {
|
||||
return &NoProxy{config: config}
|
||||
}
|
||||
|
||||
// Close removes peer from the WireGuard interface
|
||||
func (p *NoProxy) Close() error {
|
||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||
if err != nil {
|
||||
@@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start just updates Wireguard peer with the remote IP and default Wireguard port
|
||||
// Start just updates WireGuard peer with the remote address
|
||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||
|
||||
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey)
|
||||
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr.Port = p.RemoteWgListenPort
|
||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||
addr, p.config.PreSharedKey)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *NoProxy) Type() Type {
|
||||
|
||||
@@ -13,9 +13,10 @@ const DefaultWgKeepAlive = 25 * time.Second
|
||||
type Type string
|
||||
|
||||
const (
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
TypeWireguard Type = "Wireguard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeDirectNoProxy Type = "DirectNoProxy"
|
||||
TypeWireGuard Type = "WireGuard"
|
||||
TypeDummy Type = "Dummy"
|
||||
TypeNoProxy Type = "NoProxy"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// WireguardProxy proxies
|
||||
type WireguardProxy struct {
|
||||
// WireGuardProxy proxies
|
||||
type WireGuardProxy struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
@@ -17,13 +17,13 @@ type WireguardProxy struct {
|
||||
localConn net.Conn
|
||||
}
|
||||
|
||||
func NewWireguardProxy(config Config) *WireguardProxy {
|
||||
p := &WireguardProxy{config: config}
|
||||
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
||||
p := &WireGuardProxy{config: config}
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) updateEndpoint() error {
|
||||
func (p *WireGuardProxy) updateEndpoint() error {
|
||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
||||
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
@@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Close() error {
|
||||
func (p *WireGuardProxy) Close() error {
|
||||
p.cancel()
|
||||
if c := p.localConn; c != nil {
|
||||
err := p.localConn.Close()
|
||||
@@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
|
||||
|
||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||
// blocks
|
||||
func (p *WireguardProxy) proxyToRemote() {
|
||||
func (p *WireGuardProxy) proxyToRemote() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
@@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
|
||||
|
||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||
// blocks
|
||||
func (p *WireguardProxy) proxyToLocal() {
|
||||
func (p *WireGuardProxy) proxyToLocal() {
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
@@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WireguardProxy) Type() Type {
|
||||
return TypeWireguard
|
||||
func (p *WireGuardProxy) Type() Type {
|
||||
return TypeWireGuard
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type routerPeerStatus struct {
|
||||
@@ -26,7 +26,7 @@ type routesUpdate struct {
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
statusRecorder *status.Status
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[string]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
@@ -37,7 +37,7 @@ type clientNetwork struct {
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
@@ -62,7 +62,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
continue
|
||||
}
|
||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||
connected: peerStatus.ConnStatus == peer.StatusConnected.String(),
|
||||
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||
relayed: peerStatus.Relayed,
|
||||
direct: peerStatus.Direct,
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
||||
return
|
||||
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus == peer.StatusConnecting.String() {
|
||||
if err != nil || state.ConnStatus == peer.StatusConnecting {
|
||||
continue
|
||||
}
|
||||
peerStateUpdate <- struct{}{}
|
||||
@@ -144,7 +144,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus != peer.StatusConnected.String() {
|
||||
if err != nil || state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String())
|
||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||
c.network, err)
|
||||
@@ -201,10 +201,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
|
||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
const (
|
||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
@@ -2,15 +2,15 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// Manager is a route manager interface
|
||||
@@ -25,26 +25,20 @@ type DefaultManager struct {
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRoutes map[string]*route.Route
|
||||
serverRouter *serverRouter
|
||||
statusRecorder *status.Status
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
}
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager {
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
return &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRoutes: make(map[string]*route.Route),
|
||||
serverRouter: &serverRouter{
|
||||
routes: make(map[string]*route.Route),
|
||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||
firewall: NewFirewall(ctx),
|
||||
},
|
||||
serverRouter: newServerRouter(ctx, wgInterface),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
@@ -54,86 +48,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.serverRoutes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.serverRoutes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.serverRoutes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.serverRoutes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.serverRoutes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.serverRoutes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
@@ -170,7 +85,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < 7 {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||
system.NetbirdVersion(), newRoute.Network)
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
@@ -179,7 +94,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
|
||||
err := m.updateServerRoutes(newServerRoutesMap)
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -187,3 +102,29 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
update := routesUpdate{
|
||||
updateSerial: updateSerial,
|
||||
routes: routes,
|
||||
}
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,14 +3,16 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/status"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
||||
@@ -390,14 +392,19 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
statusRecorder := status.NewRecorder()
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
||||
defer routeManager.Stop()
|
||||
@@ -413,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
|
||||
if testCase.shouldCheckServerRoutes {
|
||||
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/google/nftables"
|
||||
|
||||
const (
|
||||
nftablesTable = "netbird-rt"
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
24
client/internal/routemanager/router_pair.go
Normal file
24
client/internal/routemanager/router_pair.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type routerPair struct {
|
||||
ID string
|
||||
source string
|
||||
destination string
|
||||
masquerade bool
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return routerPair{
|
||||
ID: route.ID,
|
||||
source: parsed.String(),
|
||||
destination: route.Network.Masked().String(),
|
||||
masquerade: route.Masquerade,
|
||||
}
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
routes map[string]*route.Route
|
||||
// best effort to keep net forward configuration as it was
|
||||
netForwardHistoryEnabled bool
|
||||
mux sync.Mutex
|
||||
firewall firewallManager
|
||||
}
|
||||
|
||||
type routerPair struct {
|
||||
ID string
|
||||
source string
|
||||
destination string
|
||||
masquerade bool
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
return routerPair{
|
||||
ID: route.ID,
|
||||
source: parsed.String(),
|
||||
destination: route.Network.Masked().String(),
|
||||
masquerade: route.Masquerade,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.serverRouter.routes, route.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.serverRouter.mux.Lock()
|
||||
defer m.serverRouter.mux.Unlock()
|
||||
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.serverRouter.routes[route.ID] = route
|
||||
return nil
|
||||
}
|
||||
}
|
||||
21
client/internal/routemanager/server_android.go
Normal file
21
client/internal/routemanager/server_android.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{}
|
||||
}
|
||||
|
||||
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *serverRouter) cleanUp() {}
|
||||
120
client/internal/routemanager/server_nonandroid.go
Normal file
120
client/internal/routemanager/server_nonandroid.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[string]*route.Route
|
||||
firewall firewallManager
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[string]*route.Route),
|
||||
firewall: NewFirewall(ctx),
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
err := m.firewall.RestoreOrCreateContainers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for routeID := range m.routes {
|
||||
update, found := routesMap[routeID]
|
||||
if !found || !update.IsEqual(m.routes[routeID]) {
|
||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range serverRoutesToRemove {
|
||||
oldRoute := m.routes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.routes, routeID)
|
||||
}
|
||||
|
||||
for id, newRoute := range routesMap {
|
||||
_, found := m.routes[id]
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.routes[id] = newRoute
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.routes, route.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.routes[route.ID] = route
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) cleanUp() {
|
||||
m.firewall.CleanRoutingRules()
|
||||
}
|
||||
13
client/internal/routemanager/systemops_android.go
Normal file
13
client/internal/routemanager/systemops_android.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
@@ -62,12 +65,3 @@ func enableIPForwarding() error {
|
||||
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
|
||||
return err
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
out, err := os.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
// todo
|
||||
panic(err)
|
||||
}
|
||||
return string(out) == "1"
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errRouteNotFound = fmt.Errorf("route not found")
|
||||
@@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -32,25 +33,29 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String())
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
if testCase.shouldRouteToWireguard {
|
||||
require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
} else {
|
||||
require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||
}
|
||||
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String())
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
||||
@@ -4,10 +4,11 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
@@ -34,8 +35,3 @@ func enableIPForwarding() error {
|
||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNetForwardHistoryEnabled() bool {
|
||||
log.Infof("check netforwad history is not implemented on %s", runtime.GOOS)
|
||||
return false
|
||||
}
|
||||
|
||||
14
client/internal/stdnet/discover.go
Normal file
14
client/internal/stdnet/discover.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package stdnet
|
||||
|
||||
import "github.com/pion/transport/v2"
|
||||
|
||||
// ExternalIFaceDiscover provide an option for external services (mobile)
|
||||
// to collect network interface information
|
||||
type ExternalIFaceDiscover interface {
|
||||
// IFaces return with the description of the interfaces
|
||||
IFaces() (string, error)
|
||||
}
|
||||
|
||||
type iFaceDiscover interface {
|
||||
iFaces() ([]*transport.Interface, error)
|
||||
}
|
||||
98
client/internal/stdnet/discover_mobile.go
Normal file
98
client/internal/stdnet/discover_mobile.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type mobileIFaceDiscover struct {
|
||||
externalDiscover ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
|
||||
return &mobileIFaceDiscover{
|
||||
externalDiscover: externalDiscover,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
|
||||
ifacesString, err := m.externalDiscover.IFaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
interfaces := m.parseInterfacesString(ifacesString)
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
|
||||
ifs := []*transport.Interface{}
|
||||
|
||||
for _, iface := range strings.Split(interfaces, "\n") {
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Split(iface, "|")
|
||||
if len(fields) != 2 {
|
||||
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||
continue
|
||||
}
|
||||
|
||||
var name string
|
||||
var index, mtu int
|
||||
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||
if err != nil {
|
||||
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newIf := net.Interface{
|
||||
Name: name,
|
||||
Index: index,
|
||||
MTU: mtu,
|
||||
}
|
||||
if up {
|
||||
newIf.Flags |= net.FlagUp
|
||||
}
|
||||
if broadcast {
|
||||
newIf.Flags |= net.FlagBroadcast
|
||||
}
|
||||
if loopback {
|
||||
newIf.Flags |= net.FlagLoopback
|
||||
}
|
||||
if pointToPoint {
|
||||
newIf.Flags |= net.FlagPointToPoint
|
||||
}
|
||||
if multicast {
|
||||
newIf.Flags |= net.FlagMulticast
|
||||
}
|
||||
|
||||
ifc := transport.NewInterface(newIf)
|
||||
|
||||
addrs := strings.Trim(fields[1], " \n")
|
||||
foundAddress := false
|
||||
for _, addr := range strings.Split(addrs, " ") {
|
||||
if strings.Contains(addr, "%") {
|
||||
continue
|
||||
}
|
||||
ip, ipNet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
log.Warnf("%s", err)
|
||||
continue
|
||||
}
|
||||
ipNet.IP = ip
|
||||
ifc.AddAddress(ipNet)
|
||||
foundAddress = true
|
||||
}
|
||||
if foundAddress {
|
||||
ifs = append(ifs, ifc)
|
||||
}
|
||||
}
|
||||
return ifs
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user