mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
238 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e37a337164 | ||
|
|
d7efea74b6 | ||
|
|
b8c46e2654 | ||
|
|
7a46a63a14 | ||
|
|
b6211ad020 | ||
|
|
c2eaf8a1c0 | ||
|
|
dc05102b8f | ||
|
|
d1a323fa9d | ||
|
|
63d211c698 | ||
|
|
0ca06b566a | ||
|
|
cf9e447bf0 | ||
|
|
fdd23d4644 | ||
|
|
5a3ee4f9c4 | ||
|
|
5ffed796c0 | ||
|
|
ab895be4a3 | ||
|
|
96cdcf8e49 | ||
|
|
63f6514be5 | ||
|
|
afece95ae5 | ||
|
|
d78b7e5d93 | ||
|
|
67906f6da5 | ||
|
|
52b5a31058 | ||
|
|
b58094de0f | ||
|
|
456aaf2868 | ||
|
|
d379c25ff5 | ||
|
|
f86ed12cf5 | ||
|
|
5a45f79fec | ||
|
|
e7d063126d | ||
|
|
fb42fedb58 | ||
|
|
9eb1e90bbe | ||
|
|
53fb0a9754 | ||
|
|
70c7543e36 | ||
|
|
d1d01a0611 | ||
|
|
9e8725618e | ||
|
|
a40261ff7e | ||
|
|
89e8540531 | ||
|
|
9f7e13fc87 | ||
|
|
8be6e92563 | ||
|
|
b726b3262d | ||
|
|
125a7a9daf | ||
|
|
9b1a0c2df7 | ||
|
|
1568c8aa91 | ||
|
|
2f5ba96596 | ||
|
|
63568e5e0e | ||
|
|
9c4bf1e899 | ||
|
|
2c01514259 | ||
|
|
e2f27502e4 | ||
|
|
8cf2866a6a | ||
|
|
c99ae6f009 | ||
|
|
8843784312 | ||
|
|
c38d65ef4c | ||
|
|
6d4240a5ae | ||
|
|
52f5101715 | ||
|
|
e2eef4e3fd | ||
|
|
76318f3f06 | ||
|
|
db25ca21a8 | ||
|
|
a8d03d8c91 | ||
|
|
74ff2619d0 | ||
|
|
40bea645e9 | ||
|
|
e7d52beeab | ||
|
|
7a5c6b24ae | ||
|
|
90c2093018 | ||
|
|
06318a15e1 | ||
|
|
eeb38b7ecf | ||
|
|
e59d2317fe | ||
|
|
ee6be58a67 | ||
|
|
a9f5fad625 | ||
|
|
c979a4e9fb | ||
|
|
f2fc0df104 | ||
|
|
87cc53b743 | ||
|
|
7d8a69cc0c | ||
|
|
e4de1d75de | ||
|
|
73e57f17ea | ||
|
|
46f5f148da | ||
|
|
32880c56a4 | ||
|
|
2b90ff8c24 | ||
|
|
b8599f634c | ||
|
|
659110f0d5 | ||
|
|
4ad14cb46b | ||
|
|
3c485dc7a1 | ||
|
|
f7e6cdcbf0 | ||
|
|
af6fdd3af2 | ||
|
|
5781ec7a8e | ||
|
|
1219006a6e | ||
|
|
4791e41004 | ||
|
|
9131069d12 | ||
|
|
26bbc33e7a | ||
|
|
35bc493cc3 | ||
|
|
e26ec0b937 | ||
|
|
a952e7c72f | ||
|
|
22f69d7852 | ||
|
|
b23011fbe8 | ||
|
|
6ad3894a51 | ||
|
|
c81b83b346 | ||
|
|
8c5c6815e0 | ||
|
|
0c470e7838 | ||
|
|
8118d60ffb | ||
|
|
1956ca169e | ||
|
|
830dee1771 | ||
|
|
c08a96770e | ||
|
|
c6bf1c7f26 | ||
|
|
5f499d66b2 | ||
|
|
7c065bd9fc | ||
|
|
ab849f0942 | ||
|
|
aa1d31bde6 | ||
|
|
5b4dc4dd47 | ||
|
|
1324169ebb | ||
|
|
732afd8393 | ||
|
|
da7b6b11ad | ||
|
|
e260270825 | ||
|
|
d4b6d7646c | ||
|
|
8febab4076 | ||
|
|
34e2c6b943 | ||
|
|
0be8c72601 | ||
|
|
c34e53477f | ||
|
|
8d18190c94 | ||
|
|
06bec61be9 | ||
|
|
2135533f1d | ||
|
|
bb791d59f3 | ||
|
|
30f1c54ed1 | ||
|
|
5c8541ef42 | ||
|
|
fa4b8c1d42 | ||
|
|
7682fe2e45 | ||
|
|
c9b2ce08eb | ||
|
|
246abda46d | ||
|
|
e4bc76c4de | ||
|
|
bdb8383485 | ||
|
|
bb40325977 | ||
|
|
8524cc75d6 | ||
|
|
c1f164c9cb | ||
|
|
4e2d075413 | ||
|
|
f89c200ce9 | ||
|
|
d51dc4fd33 | ||
|
|
00dddb9458 | ||
|
|
1a9301b684 | ||
|
|
80d9b5fca5 | ||
|
|
ac0b7dc8cb | ||
|
|
e586eca16c | ||
|
|
892db25021 | ||
|
|
da75a76d41 | ||
|
|
3ac32fd78a | ||
|
|
3aa657599b | ||
|
|
d4e9087f94 | ||
|
|
da8447a67d | ||
|
|
8e3bcd57a2 | ||
|
|
4572c6c1f8 | ||
|
|
01f2b0ecb7 | ||
|
|
442ba7cbc8 | ||
|
|
6c2b364966 | ||
|
|
0f0c7ec2ed | ||
|
|
2dec016201 | ||
|
|
06125acb8d | ||
|
|
a9b9b3fa0a | ||
|
|
cdf57275b7 | ||
|
|
e5e69b1f75 | ||
|
|
8eca83f3cb | ||
|
|
973316d194 | ||
|
|
a0a6ced148 | ||
|
|
0fc6c477a9 | ||
|
|
401a462398 | ||
|
|
a3839a6ef7 | ||
|
|
8aa4f240c7 | ||
|
|
d9686bae92 | ||
|
|
24e19ae287 | ||
|
|
74fde0ea2c | ||
|
|
890e09b787 | ||
|
|
48098c994d | ||
|
|
64f6343fcc | ||
|
|
24713fbe59 | ||
|
|
7794b744f8 | ||
|
|
0d0c30c16d | ||
|
|
b0364da67c | ||
|
|
6dee89379b | ||
|
|
76db4f801a | ||
|
|
6c2ed4b4f2 | ||
|
|
2541c78dd0 | ||
|
|
97b6e79809 | ||
|
|
6ad3847615 | ||
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d | ||
|
|
774d8e955c | ||
|
|
c20f98c8b6 | ||
|
|
20ae540fb1 | ||
|
|
58cfa2bb17 | ||
|
|
06005cc10e | ||
|
|
1a3e377304 | ||
|
|
dd29f4c01e | ||
|
|
cb7ecd1cc4 | ||
|
|
b5d8142705 | ||
|
|
f45eb1a1da | ||
|
|
2567006412 | ||
|
|
b92107efc8 | ||
|
|
5d19811331 | ||
|
|
697d41c94e | ||
|
|
75d541f967 | ||
|
|
7dfbb71f7a | ||
|
|
a5d14c92ff | ||
|
|
ce091ab42b | ||
|
|
d2fad1cfd9 | ||
|
|
0b5594f145 | ||
|
|
9beaa91db9 | ||
|
|
c8b4c08139 | ||
|
|
dad5501a44 | ||
|
|
1ced2462c1 | ||
|
|
64adaeb276 | ||
|
|
6e26d03fb8 | ||
|
|
493ddb4fe3 | ||
|
|
75fac258e7 | ||
|
|
bc8ee8fc3c | ||
|
|
3724323f76 | ||
|
|
3ef33874b1 |
15
.devcontainer/Dockerfile
Normal file
15
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,15 @@
|
||||
FROM golang:1.20-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
gettext-base=0.21-4 \
|
||||
iptables=1.8.7-1 \
|
||||
libgl1-mesa-dev=20.3.5-1 \
|
||||
xorg-dev=1:7.7+22 \
|
||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
20
.devcontainer/devcontainer.json
Normal file
20
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"name": "NetBird",
|
||||
"build": {
|
||||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.20"
|
||||
}
|
||||
},
|
||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||
"capAdd": [
|
||||
"NET_ADMIN",
|
||||
"SYS_ADMIN",
|
||||
"SYS_RESOURCE"
|
||||
],
|
||||
"privileged": true
|
||||
}
|
||||
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.go text eol=lf
|
||||
41
.github/workflows/android-build-validation.yml
vendored
Normal file
41
.github/workflows/android-build-validation.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: Android build validation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android nebtird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
11
.github/workflows/golang-test-darwin.yml
vendored
11
.github/workflows/golang-test-darwin.yml
vendored
@@ -12,17 +12,20 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['jsonfile', 'sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -33,4 +36,4 @@ jobs:
|
||||
run: go mod tidy
|
||||
|
||||
- name: Test
|
||||
run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
|
||||
30
.github/workflows/golang-test-linux.yml
vendored
30
.github/workflows/golang-test-linux.yml
vendored
@@ -15,16 +15,17 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['386','amd64']
|
||||
store: ['jsonfile', 'sqlite']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -32,7 +33,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
@@ -41,19 +42,18 @@ jobs:
|
||||
run: go mod tidy
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -61,10 +61,10 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -82,7 +82,7 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
@@ -95,15 +95,17 @@ jobs:
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
4
.github/workflows/golang-test-windows.yml
vendored
4
.github/workflows/golang-test-windows.yml
vendored
@@ -39,7 +39,9 @@ jobs:
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
- run: choco install -y sysinternals
|
||||
- run: choco install -y sysinternals --ignore-checksums
|
||||
- run: choco install -y mingw
|
||||
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
|
||||
|
||||
39
.github/workflows/golangci-lint.yml
vendored
39
.github/workflows/golangci-lint.yml
vendored
@@ -1,21 +1,48 @@
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
codespell:
|
||||
name: codespell
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||
name: lint
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
args: --timeout=6m
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
36
.github/workflows/install-script-test.yml
vendored
Normal file
36
.github/workflows/install-script-test.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Test installation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
test-install-script:
|
||||
strategy:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
skip_ui_mode: [true, false]
|
||||
install_binary: [true, false]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
|
||||
USE_BIN_INSTALL: ${{ matrix.install_binary }}
|
||||
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
|
||||
run: |
|
||||
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
|
||||
cat release_files/install.sh | sh -x
|
||||
|
||||
- name: check cli binary
|
||||
run: command -v netbird
|
||||
60
.github/workflows/install-test-darwin.yml
vendored
60
.github/workflows/install-test-darwin.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Test installation Darwin
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
env:
|
||||
SKIP_UI_APP: true
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
install-all:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||
echo "Error: NetBird UI is not installed"
|
||||
exit 1
|
||||
fi
|
||||
38
.github/workflows/install-test-linux.yml
vendored
38
.github/workflows/install-test-linux.yml
vendored
@@ -1,38 +0,0 @@
|
||||
name: Test installation Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
check_bin_install: [true, false]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename apt package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: |
|
||||
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
63
.github/workflows/release.yml
vendored
63
.github/workflows/release.yml
vendored
@@ -7,9 +7,20 @@ on:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.goreleaser.yml'
|
||||
- '.goreleaser_ui.yaml'
|
||||
- '.goreleaser_ui_darwin.yaml'
|
||||
- '.github/workflows/release.yml'
|
||||
- 'release_files/**'
|
||||
- '**/Dockerfile'
|
||||
- '**/Dockerfile.*'
|
||||
- 'client/ui/**'
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.8"
|
||||
SIGN_PIPE_VER: "v0.0.10"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
@@ -19,20 +30,24 @@ concurrency:
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -46,10 +61,10 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
uses: docker/setup-qemu-action@v2
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
uses: docker/setup-buildx-action@v2
|
||||
-
|
||||
name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
@@ -72,10 +87,10 @@ jobs:
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --rm-dist
|
||||
args: release --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
@@ -83,7 +98,7 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
@@ -92,17 +107,19 @@ jobs:
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -116,23 +133,23 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -141,19 +158,21 @@ jobs:
|
||||
release_ui_darwin:
|
||||
runs-on: macos-11
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -165,15 +184,15 @@ jobs:
|
||||
-
|
||||
name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
name: Test Docker Compose Linux
|
||||
name: Test Infrastructure files
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
test-docker-compose:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
@@ -22,12 +26,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -35,7 +39,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: cp setup.env
|
||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||
@@ -53,6 +57,9 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP: "none"
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
working-directory: infrastructure_files
|
||||
@@ -68,6 +75,7 @@ jobs:
|
||||
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
|
||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
|
||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||
@@ -76,6 +84,9 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP: "none"
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
run: |
|
||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
@@ -87,11 +98,14 @@ jobs:
|
||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
|
||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
|
||||
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
|
||||
grep UseIDToken management.json | grep false
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -99,6 +113,34 @@ jobs:
|
||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: Build management binary
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
run: CGO_ENABLED=0 go build -o netbird-signal main.go
|
||||
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
@@ -110,6 +152,31 @@ jobs:
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen
|
||||
run: test -f Caddyfile
|
||||
- name: test docker-compose file gen
|
||||
run: test -f docker-compose.yml
|
||||
- name: test management.json file gen
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
run: test -f turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
run: test -f dashboard.env
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -7,8 +7,17 @@ bin/
|
||||
conf.json
|
||||
http-cmds.sh
|
||||
infrastructure_files/management.json
|
||||
infrastructure_files/management-*.json
|
||||
infrastructure_files/docker-compose.yml
|
||||
infrastructure_files/openid-configuration.json
|
||||
infrastructure_files/turnserver.conf
|
||||
management/management
|
||||
client/client
|
||||
client/client.exe
|
||||
*.syso
|
||||
client/.distfiles/
|
||||
infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
*.db
|
||||
123
.golangci.yaml
Normal file
123
.golangci.yaml
Normal file
@@ -0,0 +1,123 @@
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 6m
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: false
|
||||
|
||||
gosec:
|
||||
includes:
|
||||
- G101 # Look for hard coded credentials
|
||||
#- G102 # Bind to all interfaces
|
||||
- G103 # Audit the use of unsafe block
|
||||
- G104 # Audit errors not checked
|
||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
||||
#- G107 # Url provided to HTTP request as taint input
|
||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
||||
- G110 # Potential DoS vulnerability via decompression bomb
|
||||
- G111 # Potential directory traversal
|
||||
#- G112 # Potential slowloris attack
|
||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
||||
- G201 # SQL query construction using format string
|
||||
- G202 # SQL query construction using string concatenation
|
||||
- G203 # Use of unescaped data in HTML templates
|
||||
#- G204 # Audit use of command execution
|
||||
- G301 # Poor file permissions used when creating a directory
|
||||
- G302 # Poor file permissions used with chmod
|
||||
- G303 # Creating tempfile using a predictable path
|
||||
- G304 # File path provided as taint input
|
||||
- G305 # File traversal when extracting zip/tar archive
|
||||
- G306 # Poor file permissions used when writing to a new file
|
||||
- G307 # Poor file permissions used when creating a file with os.Create
|
||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
||||
#- G402 # Look for bad TLS connection settings
|
||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
||||
#- G404 # Insecure random number source (rand)
|
||||
#- G501 # Import blocklist: crypto/md5
|
||||
- G502 # Import blocklist: crypto/des
|
||||
- G503 # Import blocklist: crypto/rc4
|
||||
- G504 # Import blocklist: net/http/cgi
|
||||
#- G505 # Import blocklist: crypto/sha1
|
||||
- G601 # Implicit memory aliasing of items from a range statement
|
||||
- G602 # Slice access out of bounds
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: false
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||
- gosimple # specializes in simplifying a code
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # detects when assignments to existing variables are not used
|
||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # checks for unused constants, variables, functions and types
|
||||
## disable by default but the have interesting results so lets add them
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- dupword # dupword checks for duplicate words in the source code
|
||||
- durationcheck # durationcheck checks for two durations multiplied together
|
||||
- forbidigo # forbidigo forbids identifiers
|
||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||
- gosec # inspects source code for security problems
|
||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
||||
- misspell # misspess finds commonly misspelled English words in comments
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 5
|
||||
|
||||
exclude-rules:
|
||||
# allow fmt
|
||||
- path: management/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: signal/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: sharedsock/filter\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: client/firewall/iptables/rule\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: test\.go
|
||||
linters:
|
||||
- mirror
|
||||
- gosec
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
@@ -377,3 +377,13 @@ uploads:
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
method: PUT
|
||||
|
||||
checksum:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
@@ -52,12 +54,9 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/disconnected.png
|
||||
- src: client/ui/netbird-systemtray-default.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
@@ -72,12 +71,9 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/disconnected.png
|
||||
- src: client/ui/netbird-systemtray-default.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
uploads:
|
||||
@@ -95,4 +91,4 @@ uploads:
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
method: PUT
|
||||
method: PUT
|
||||
|
||||
@@ -23,7 +23,6 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
|
||||
- [Test suite](#test-suite)
|
||||
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
|
||||
- [Other project repositories](#other-project-repositories)
|
||||
- [Checklist before submitting a new node](#checklist-before-submitting-a-new-node)
|
||||
- [Contributor License Agreement](#contributor-license-agreement)
|
||||
|
||||
## Code of conduct
|
||||
@@ -70,7 +69,7 @@ dependencies are installed. Here is a short guide on how that can be done.
|
||||
|
||||
### Requirements
|
||||
|
||||
#### Go 1.19
|
||||
#### Go 1.21
|
||||
|
||||
Follow the installation guide from https://go.dev/
|
||||
|
||||
@@ -139,15 +138,14 @@ checked out and set up:
|
||||
### Build and start
|
||||
#### Client
|
||||
|
||||
> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh).
|
||||
|
||||
To start NetBird, execute:
|
||||
```
|
||||
cd client
|
||||
# bash wireguard_nt.sh # if windows
|
||||
go build .
|
||||
CGO_ENABLED=0 go build .
|
||||
```
|
||||
|
||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
|
||||
|
||||
To start NetBird the client in the foreground:
|
||||
|
||||
```
|
||||
@@ -185,6 +183,42 @@ To start NetBird the management service:
|
||||
./management management --log-level debug --log-file console --config ./management.json
|
||||
```
|
||||
|
||||
#### Windows Netbird Installer
|
||||
Create dist directory
|
||||
```shell
|
||||
mkdir -p dist/netbird_windows_amd64
|
||||
```
|
||||
|
||||
UI client
|
||||
```shell
|
||||
CC=x86_64-w64-mingw32-gcc CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -o netbird-ui.exe -ldflags "-s -w -H windowsgui" ./client/ui
|
||||
mv netbird-ui.exe ./dist/netbird_windows_amd64/
|
||||
```
|
||||
|
||||
Client
|
||||
```shell
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o netbird.exe ./client/
|
||||
mv netbird.exe ./dist/netbird_windows_amd64/
|
||||
```
|
||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to `./dist/netbird_windows_amd64/`.
|
||||
|
||||
NSIS compiler
|
||||
- [Windows-nsis]( https://nsis.sourceforge.io/Download)
|
||||
- [MacOS-makensis](https://formulae.brew.sh/formula/makensis#default)
|
||||
- [Linux-makensis](https://manpages.ubuntu.com/manpages/trusty/man1/makensis.1.html)
|
||||
|
||||
NSIS Plugins. Download and move them to the NSIS plugins folder.
|
||||
- [EnVar](https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip)
|
||||
- [ShellExecAsUser](https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z)
|
||||
|
||||
Windows Installer
|
||||
```shell
|
||||
export APPVER=0.0.0.1
|
||||
makensis -V4 client/installer.nsis
|
||||
```
|
||||
|
||||
The installer `netbird-installer.exe` will be created in root directory.
|
||||
|
||||
### Test suite
|
||||
|
||||
The tests can be started via:
|
||||
@@ -215,4 +249,4 @@ NetBird project is composed of 3 main repositories:
|
||||
|
||||
That we do not have any potential problems later it is sadly necessary to sign a [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md). That can be done literally with the push of a button.
|
||||
|
||||
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
|
||||
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
|
||||
|
||||
95
README.md
95
README.md
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird/releases">
|
||||
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
@@ -24,7 +24,7 @@
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
@@ -36,47 +36,62 @@
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.**
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
|
||||
|
||||
**Key features:**
|
||||
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
|
||||
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
|
||||
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
|
||||
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
|
||||
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
|
||||
- \[x] Multiuser support - sharing network between multiple users.
|
||||
- \[x] SSO and MFA support.
|
||||
- \[x] Multicloud and hybrid-cloud support.
|
||||
- \[x] Kernel WireGuard usage when possible.
|
||||
- \[x] Access Controls - groups & rules.
|
||||
- \[x] Remote SSH access without managing SSH keys.
|
||||
- \[x] Network Routes.
|
||||
- \[x] Private DNS.
|
||||
- \[x] Network Activity Monitoring.
|
||||
|
||||
**Coming soon:**
|
||||
- \[ ] Mobile clients.
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
|
||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
### Key features
|
||||
|
||||
### Start using NetBird
|
||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||
| Connectivity | Management | Automation | Platforms |
|
||||
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
||||
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
@@ -88,18 +103,18 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
||||
|
||||
### Roadmap
|
||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
@@ -107,7 +122,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -18,10 +18,9 @@ func Encode(num uint32) string {
|
||||
}
|
||||
|
||||
var encoded strings.Builder
|
||||
remainder := uint32(0)
|
||||
|
||||
for num > 0 {
|
||||
remainder = num % base
|
||||
remainder := num % base
|
||||
encoded.WriteByte(alphabet[remainder])
|
||||
num /= base
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
FROM gcr.io/distroless/base:debug
|
||||
FROM alpine:3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
|
||||
SHELL ["/busybox/sh","-c"]
|
||||
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
|
||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||
COPY netbird /go/bin/netbird
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -30,9 +31,14 @@ type IFaceDiscover interface {
|
||||
stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
// RouteListener export internal RouteListener for mobile
|
||||
type RouteListener interface {
|
||||
routemanager.RouteListener
|
||||
// NetworkChangeListener export internal NetworkChangeListener for mobile
|
||||
type NetworkChangeListener interface {
|
||||
listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||
type DnsReadyListener interface {
|
||||
dns.ReadyListener
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -41,31 +47,31 @@ func init() {
|
||||
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
routeListener routemanager.RouteListener
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
routeListener: routeListener,
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
networkChangeListener: networkChangeListener,
|
||||
}
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -90,7 +96,31 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||
c.ctxCancelLock.Lock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
defer c.ctxCancel()
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -126,6 +156,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConnectionListener set the network connection listener
|
||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.SetConnectionListener(listener)
|
||||
|
||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
// DNSList is a wrapper of []string
|
||||
type DNSList struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
// Add new DNS address to the collection
|
||||
func (array *DNSList) Add(s string) {
|
||||
array.items = append(array.items, s)
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array *DNSList) Get(i int) (string, error) {
|
||||
if i >= len(array.items) || i < 0 {
|
||||
return "", fmt.Errorf("out of range")
|
||||
}
|
||||
return array.items[i], nil
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array *DNSList) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package android
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDNSList_Get(t *testing.T) {
|
||||
l := DNSList{
|
||||
items: make([]string, 1),
|
||||
}
|
||||
|
||||
_, err := l.Get(0)
|
||||
if err != nil {
|
||||
t.Errorf("invalid error: %s", err)
|
||||
}
|
||||
|
||||
_, err = l.Get(-1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
|
||||
_, err = l.Get(1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
// SSOListener is async listener for mobile framework
|
||||
@@ -85,11 +84,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -183,35 +192,23 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||
p.SetManagementURL(exampleString)
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
t.Fatalf("failed to read management url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleString {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
t.Errorf("unexpected management url: %s", resp)
|
||||
}
|
||||
|
||||
p.SetPreSharedKey(exampleString)
|
||||
@@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) {
|
||||
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
t.Fatalf("failed to read management url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleURL {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
t.Errorf("unexpected management url: %s", resp)
|
||||
}
|
||||
|
||||
resp, err = p.GetPreSharedKey()
|
||||
|
||||
@@ -3,20 +3,20 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -81,9 +81,11 @@ var loginCmd = &cobra.Command{
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
@@ -113,7 +115,7 @@ var loginCmd = &cobra.Command{
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
@@ -163,40 +165,24 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
mgmtURL := managementURL
|
||||
if mgmtURL == "" {
|
||||
mgmtURL = internal.DefaultManagementURL
|
||||
}
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your servver or use Setup Keys to login", mgmtURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
@@ -206,15 +192,21 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
var codeMsg string
|
||||
if !strings.Contains(verificationURIComplete, userCode) {
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
err := open.Run(verificationURIComplete)
|
||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||
if err != nil {
|
||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
cmd.Println("")
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
|
||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Print(err)
|
||||
log.Debug(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. " +
|
||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||
"Run the status command: \n\n" +
|
||||
" netbird status\n\n" +
|
||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||
return nil
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"You can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -109,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
resp, _ := getStatus(ctx, cmd)
|
||||
resp, err := getStatus(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
@@ -120,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
" netbird up \n\n"+
|
||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
|
||||
resp.GetStatus(),
|
||||
)
|
||||
return nil
|
||||
@@ -133,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
statusOutputString := ""
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
@@ -234,7 +234,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected = peersConnected + 1
|
||||
peersConnected++
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
@@ -407,7 +407,7 @@ func parsePeers(peers peersStateOutput) string {
|
||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
|
||||
peersString = peersString + peerString
|
||||
peersString += peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
)
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
@@ -44,6 +45,7 @@ func startTestingServices(t *testing.T) string {
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -60,28 +62,29 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewFileStore(config.Datadir, nil)
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
eventStore, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -98,6 +101,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
@@ -123,7 +123,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing dameon gRPC client connection %v", err)
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -141,13 +141,15 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
@@ -178,7 +180,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
@@ -199,11 +201,11 @@ func validateNATExternalIPs(list []string) error {
|
||||
|
||||
subElements := strings.Split(element, "/")
|
||||
if len(subElements) > 2 {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
if len(subElements) == 1 && !isValidIP(subElements[0]) {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
last := 0
|
||||
@@ -258,7 +260,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
var parsed []byte
|
||||
if modified {
|
||||
if !isValidAddrPort(customDNSAddress) {
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
}
|
||||
if customDNSAddress == "" && logFile != "console" {
|
||||
parsed = []byte("empty")
|
||||
|
||||
@@ -40,6 +40,9 @@ const (
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
@@ -51,6 +54,7 @@ type Manager interface {
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -60,5 +64,8 @@ type Manager interface {
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -35,36 +36,53 @@ type Manager struct {
|
||||
inputDefaultRuleSpecs []string
|
||||
outputDefaultRuleSpecs []string
|
||||
wgIface iFaceMapper
|
||||
|
||||
rulesets map[string]ruleset
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
type ruleset struct {
|
||||
rule *Rule
|
||||
ips map[string]string
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
inputDefaultRuleSpecs: []string{
|
||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||
outputDefaultRuleSpecs: []string{
|
||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||
rulesets: make(map[string]ruleset),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv4Client = ipv4Client
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
m.ipv6Client = ipv6Client
|
||||
if ipv6Supported {
|
||||
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
||||
}
|
||||
}
|
||||
|
||||
if m.ipv4Client == nil && m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
@@ -75,7 +93,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol fw.Protocol,
|
||||
@@ -83,6 +101,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -101,22 +120,41 @@ func (m *Manager) AddFiltering(
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
|
||||
if ipsetName != "" {
|
||||
rs, rsExists := m.rulesets[ipsetName]
|
||||
if !rsExists {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
if rsExists {
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
rs.ips[ip.String()] = ruleID
|
||||
return &Rule{
|
||||
ruleID: ruleID,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
}
|
||||
// this is new ipset so we need to create firewall rule for it
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(
|
||||
"filter",
|
||||
ip,
|
||||
string(protocol),
|
||||
sPortVal,
|
||||
dPortVal,
|
||||
direction,
|
||||
action,
|
||||
comment,
|
||||
)
|
||||
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||
@@ -144,12 +182,24 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
rule := &Rule{
|
||||
ruleID: ruleID,
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
if ipsetName != "" {
|
||||
// ipset name is defined and it means that this rule was created
|
||||
// for it, need to associate it with ruleset
|
||||
m.rulesets[ipsetName] = ruleset{
|
||||
rule: rule,
|
||||
ips: map[string]string{rule.ip: ruleID},
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
@@ -170,6 +220,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
client = m.ipv6Client
|
||||
}
|
||||
|
||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := rs.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(rs.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(rs.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
delete(m.rulesets, r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
r = rs.rule
|
||||
}
|
||||
|
||||
if r.dst {
|
||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||
}
|
||||
@@ -193,6 +268,41 @@ func (m *Manager) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if m.wgIface.IsUserspaceBind() {
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionOUT,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
@@ -233,13 +343,22 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
for ipsetName := range m.rulesets {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
delete(m.rulesets, ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||
direction fw.RuleDirection, action fw.Action, comment string,
|
||||
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
@@ -249,11 +368,19 @@ func (m *Manager) filterRuleSpecs(
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case fw.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
@@ -265,8 +392,7 @@ func (m *Manager) filterRuleSpecs(
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
return append(specs, "-j", m.actionToStr(action))
|
||||
}
|
||||
|
||||
// rawClient returns corresponding iptables client for the given ip
|
||||
@@ -301,7 +427,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -335,3 +461,18 @@ func (m *Manager) actionToStr(action fw.Action) string {
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
@@ -53,14 +55,15 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
@@ -68,7 +71,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
@@ -81,33 +84,31 @@ func TestIptablesManager(t *testing.T) {
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule1); err != nil {
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule2); err != nil {
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
@@ -122,7 +123,90 @@ func TestIptablesManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||
t.Helper()
|
||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||
require.NoError(t, err, "failed to check rule")
|
||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
||||
@@ -148,14 +232,14 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
err := manager.Reset()
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
@@ -167,9 +251,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -2,13 +2,16 @@ package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
id string
|
||||
ruleID string
|
||||
ipsetName string
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
dst bool
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -27,13 +29,18 @@ const (
|
||||
|
||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
|
||||
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn *nftables.Conn
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
@@ -43,6 +50,10 @@ type Manager struct {
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
rulesetManager *rulesetManager
|
||||
setRemovedIPs map[string]struct{}
|
||||
setRemoved map[string]*nftables.Set
|
||||
|
||||
wgIface iFaceMapper
|
||||
}
|
||||
|
||||
@@ -54,8 +65,23 @@ type iFaceMapper interface {
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
conn: &nftables.Conn{},
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
|
||||
rulesetManager: newRuleManager(),
|
||||
setRemovedIPs: map[string]struct{}{},
|
||||
setRemoved: map[string]*nftables.Set{},
|
||||
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
@@ -77,6 +103,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -84,6 +111,7 @@ func (m *Manager) AddFiltering(
|
||||
|
||||
var (
|
||||
err error
|
||||
ipset *nftables.Set
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
@@ -107,6 +135,46 @@ func (m *Manager) AddFiltering(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
if !isSetNew {
|
||||
// if we already have nftables rules with set for given direction
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
}
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
@@ -146,39 +214,47 @@ func (m *Manager) AddFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
addrLen := uint32(len(rawIP))
|
||||
addrOffset := uint32(12)
|
||||
if addrLen == 16 {
|
||||
addrOffset = 8
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Offset: addrOffset,
|
||||
Len: addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
@@ -219,39 +295,76 @@ func (m *Manager) AddFiltering(
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
|
||||
_ = m.conn.InsertRule(&nftables.Rule{
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
list, err := m.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
func (m *Manager) getRulesetID(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipsetName == "" {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipsetName + rulesetID
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *Manager) createSet(
|
||||
table *nftables.Table,
|
||||
rawIP []byte,
|
||||
name string,
|
||||
) (*nftables.Set, error) {
|
||||
keyType := nftables.TypeIPAddr
|
||||
if len(rawIP) == 16 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
// else we create new ipset and continue creating rule
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
|
||||
// Add the rule to the chain
|
||||
rule := &Rule{id: id}
|
||||
for _, r := range list {
|
||||
if bytes.Equal(r.UserData, userData) {
|
||||
rule.Rule = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if rule.Rule == nil {
|
||||
return nil, fmt.Errorf("rule not found")
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
@@ -268,7 +381,7 @@ func (m *Manager) chain(
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
||||
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
@@ -288,13 +401,20 @@ func (m *Manager) chain(
|
||||
}
|
||||
|
||||
// table returns the table for the given family of the IP address
|
||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
func (m *Manager) table(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
// we cache access to Netbird ACL table only
|
||||
if tableName != FilterTableName {
|
||||
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
}
|
||||
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if m.tableIPv4 != nil {
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -306,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -314,34 +434,41 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.conn.ListTablesOfFamily(family)
|
||||
func (m *Manager) createTableIfNotExists(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
if t.Name == tableName {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
family nftables.TableFamily,
|
||||
tableName string,
|
||||
name string,
|
||||
hooknum nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
chainType nftables.ChainType,
|
||||
) (*nftables.Chain, error) {
|
||||
table, err := m.table(family)
|
||||
table, err := m.table(family, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
@@ -362,7 +489,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.conn.AddChain(chain)
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
shiftDSTAddr := 0
|
||||
@@ -429,7 +556,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
)
|
||||
}
|
||||
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
@@ -444,12 +571,13 @@ func (m *Manager) createChainIfNotExists(
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -458,16 +586,58 @@ func (m *Manager) createChainIfNotExists(
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
||||
return err
|
||||
if nativeRule.nftRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
if nativeRule.nftSet != nil {
|
||||
// call twice of delete set element raises error
|
||||
// so we need to check if element is already removed
|
||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.setRemovedIPs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if m.rulesetManager.deleteRule(nativeRule) {
|
||||
// deleteRule indicates that we still have IP in the ruleset
|
||||
// it means we should not remove the nftables rule but need to update set
|
||||
// so we prepare IP to be removed from set on the next flush call
|
||||
return nil
|
||||
}
|
||||
|
||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
nativeRule.nftRule = nil
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||
}
|
||||
nativeRule.nftSet = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -475,27 +645,217 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.conn.ListChains()
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.conn.DelChain(c)
|
||||
m.rConn.DelChain(c)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := m.conn.ListTables()
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
m.conn.DelTable(t)
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set must be removed after flush rule changes
|
||||
// otherwise we will get error
|
||||
for _, s := range m.setRemoved {
|
||||
m.rConn.FlushSet(s)
|
||||
m.rConn.DelSet(s)
|
||||
}
|
||||
|
||||
if len(m.setRemoved) > 0 {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.setRemovedIPs = map[string]struct{}{}
|
||||
m.setRemoved = map[string]*nftables.Set{}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
tf := nftables.TableFamilyIPv4
|
||||
if m.wgIface.Address().IP.To4() == nil {
|
||||
tf = nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||
if table == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) != 0 {
|
||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||
log.Errorf("failed to set rule handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(AllowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
// test expectations:
|
||||
// 1) regular rule
|
||||
// 2) "accept extra routed traffic rule" for the interface
|
||||
@@ -135,12 +140,15 @@ func TestNftablesManager(t *testing.T) {
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// test expectations:
|
||||
// 1) "accept extra routed traffic rule" for the interface
|
||||
// 2) "drop all rule" for the interface
|
||||
require.Len(t, rules, 2, "expected 2 rules after deleteion")
|
||||
require.Len(t, rules, 2, "expected 2 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
*nftables.Rule
|
||||
id string
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
|
||||
ruleID string
|
||||
ip []byte
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||
type nftRuleset struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
issuedRules map[string]*Rule
|
||||
rulesetID string
|
||||
}
|
||||
|
||||
type rulesetManager struct {
|
||||
rulesets map[string]*nftRuleset
|
||||
|
||||
nftSetName2rulesetID map[string]string
|
||||
issuedRuleID2rulesetID map[string]string
|
||||
}
|
||||
|
||||
func newRuleManager() *rulesetManager {
|
||||
return &rulesetManager{
|
||||
rulesets: map[string]*nftRuleset{},
|
||||
|
||||
nftSetName2rulesetID: map[string]string{},
|
||||
issuedRuleID2rulesetID: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||
ruleset, ok := r.rulesets[rulesetID]
|
||||
return ruleset, ok
|
||||
}
|
||||
|
||||
func (r *rulesetManager) createRuleset(
|
||||
rulesetID string,
|
||||
nftRule *nftables.Rule,
|
||||
nftSet *nftables.Set,
|
||||
) *nftRuleset {
|
||||
ruleset := nftRuleset{
|
||||
rulesetID: rulesetID,
|
||||
nftRule: nftRule,
|
||||
nftSet: nftSet,
|
||||
issuedRules: map[string]*Rule{},
|
||||
}
|
||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||
if nftSet != nil {
|
||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||
}
|
||||
return &ruleset
|
||||
}
|
||||
|
||||
func (r *rulesetManager) addRule(
|
||||
ruleset *nftRuleset,
|
||||
ip []byte,
|
||||
) (*Rule, error) {
|
||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||
return nil, fmt.Errorf("ruleset not found")
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
nftRule: ruleset.nftRule,
|
||||
nftSet: ruleset.nftSet,
|
||||
ruleID: xid.New().String(),
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
ruleset.issuedRules[rule.ruleID] = &rule
|
||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// deleteRule from ruleset and returns true if contains other rules
|
||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ruleset := r.rulesets[rulesetID]
|
||||
if ruleset.nftRule == nil {
|
||||
return false
|
||||
}
|
||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||
delete(ruleset.issuedRules, rule.ruleID)
|
||||
|
||||
if len(ruleset.issuedRules) == 0 {
|
||||
delete(r.rulesets, ruleset.rulesetID)
|
||||
if rule.nftSet != nil {
|
||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||
//
|
||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||
// we set correct handle value to it.
|
||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||
ruleset, ok := r.rulesets[string(split[0])]
|
||||
if !ok {
|
||||
return fmt.Errorf("ruleset not found")
|
||||
}
|
||||
*ruleset.nftRule = *nftRule
|
||||
return nil
|
||||
}
|
||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{
|
||||
UserData: []byte(rulesetID),
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_addRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||
|
||||
ruleset2 := &nftRuleset{
|
||||
rulesetID: "ruleset-2",
|
||||
}
|
||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||
require.Error(t, err, "addRule() should have failed")
|
||||
}
|
||||
|
||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
ip2 := []byte("192.168.1.1")
|
||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule2, "rule should not be nil")
|
||||
|
||||
hasNext := rulesetManager.deleteRule(rule)
|
||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||
|
||||
// Check that the rule is no longer in the manager.
|
||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||
|
||||
hasNext = rulesetManager.deleteRule(rule2)
|
||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||
}
|
||||
|
||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.0.1")
|
||||
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
nftRuleCopy := nftRule
|
||||
nftRuleCopy.Handle = 2
|
||||
nftRuleCopy.UserData = []byte(rulesetID)
|
||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||
// check correct work with references
|
||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
nftSet := nftables.Set{
|
||||
ID: 2,
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
|
||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||
require.True(t, ok, "getRuleset() failed")
|
||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||
|
||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||
require.False(t, ok, "getRuleset() failed")
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
@@ -28,3 +32,15 @@ type Port struct {
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
func (p *Port) String() string {
|
||||
var ports string
|
||||
for _, port := range p.Values {
|
||||
if ports != "" {
|
||||
ports += ","
|
||||
}
|
||||
ports += strconv.Itoa(port)
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
19
client/firewall/uspfilter/allow_netbird.go
Normal file
19
client/firewall/uspfilter/allow_netbird.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !windows && !linux
|
||||
|
||||
package uspfilter
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package uspfilter
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.resetHook != nil {
|
||||
return m.resetHook()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
94
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
94
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type action string
|
||||
|
||||
const (
|
||||
addRule action = "add"
|
||||
deleteRule action = "delete"
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||
|
||||
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||
if action == addRule {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func isWindowsFirewallReachable() bool {
|
||||
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isFirewallRuleActive(ruleName string) bool {
|
||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
_, err := cmd.Output()
|
||||
return err == nil
|
||||
}
|
||||
@@ -19,15 +19,20 @@ const layerTypeAll = 0
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules []Rule
|
||||
incomingRules []Rule
|
||||
rulesIndex map[string]int
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
resetHook func() error
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
@@ -48,7 +53,6 @@ type decoder struct {
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rulesIndex: make(map[string]int),
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -62,6 +66,9 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -81,6 +88,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
r := Rule{
|
||||
@@ -124,15 +132,17 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var p int
|
||||
if direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules, r)
|
||||
p = len(m.incomingRules) - 1
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules, r)
|
||||
p = len(m.outgoingRules) - 1
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
m.rulesIndex[r.id] = p
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
@@ -148,38 +158,25 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
p, ok := m.rulesIndex[r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.rulesIndex, r.id)
|
||||
|
||||
var toUpdate []Rule
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
||||
toUpdate = m.incomingRules
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
||||
toUpdate = m.outgoingRules
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
}
|
||||
|
||||
for i := 0; i < len(toUpdate); i++ {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = m.outgoingRules[:0]
|
||||
m.incomingRules = m.incomingRules[:0]
|
||||
m.rulesIndex = make(map[string]int)
|
||||
|
||||
return nil
|
||||
}
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
@@ -191,8 +188,8 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules, true)
|
||||
}
|
||||
|
||||
// dropFilter imlements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
||||
// dropFilter implements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -224,37 +221,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
// check if IP address match by IP
|
||||
var ip net.IP
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip4.SrcIP
|
||||
} else {
|
||||
ip = d.ip4.DstIP
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip6.SrcIP
|
||||
} else {
|
||||
ip = d.ip6.DstIP
|
||||
}
|
||||
}
|
||||
|
||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
@@ -264,38 +273,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.udpHook(packetData)
|
||||
return rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
return false, false
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
@@ -325,19 +332,19 @@ func (m *Manager) AddUDPPacketHook(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var toUpdate []Rule
|
||||
if in {
|
||||
r.direction = fw.RuleDirectionIN
|
||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
||||
toUpdate = m.incomingRules
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
||||
toUpdate = m.outgoingRules
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
|
||||
for i := range toUpdate {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -345,15 +352,26 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range m.incomingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, r := range m.outgoingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetResetHook which will be executed in the end of Reset method
|
||||
func (m *Manager) SetResetHook(hook func() error) {
|
||||
m.resetHook = hook
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc == nil {
|
||||
return iface.WGAddress{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
@@ -63,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -98,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -111,7 +119,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -123,8 +131,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
||||
t.Errorf("rule2 is not in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
@@ -133,8 +141,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rule1 still in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,26 +177,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &Manager{
|
||||
incomingRules: []Rule{},
|
||||
outgoingRules: []Rule{},
|
||||
rulesIndex: make(map[string]int),
|
||||
incomingRules: map[string]RuleSet{},
|
||||
outgoingRules: map[string]RuleSet{},
|
||||
}
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules) != 1 {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.incomingRules[0]
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.outgoingRules[0]
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
@@ -211,17 +222,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure rulesIndex is correctly updated
|
||||
index, ok := manager.rulesIndex[addedRule.id]
|
||||
if !ok {
|
||||
t.Errorf("expected rule to be in rulesIndex")
|
||||
return
|
||||
}
|
||||
if index != 0 {
|
||||
t.Errorf("expected rule index to be 0, got %d", index)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -244,7 +244,7 @@ func TestManagerReset(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -256,7 +256,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rules is not empty")
|
||||
}
|
||||
}
|
||||
@@ -282,7 +282,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -346,10 +346,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,9 +366,11 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -394,9 +398,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -166,10 +166,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
EnVar::SetHKLM
|
||||
EnVar::AddValueEx "path" "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
SetShellVarContext all
|
||||
SectionEnd
|
||||
|
||||
Section -Post
|
||||
@@ -196,10 +195,9 @@ Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||
SetShellVarContext all
|
||||
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
@@ -209,8 +207,7 @@ SectionEnd
|
||||
|
||||
|
||||
Function LaunchLink
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
SetOutPath $INSTDIR
|
||||
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
||||
SetShellVarContext all
|
||||
FunctionEnd
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -30,23 +33,53 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
manager firewall.Manager
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
manager firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type ipsetInfo struct {
|
||||
name string
|
||||
ipCount int
|
||||
}
|
||||
|
||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
//
|
||||
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
for _, pairs := range d.rulesPairs {
|
||||
total += len(pairs)
|
||||
}
|
||||
log.Infof(
|
||||
"ACL rules processed in: %v, total rules count: %d",
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.manager == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.manager.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := (networkMap.PeerConfig != nil &&
|
||||
@@ -94,14 +127,37 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||
|
||||
// calculate which IP's can be grouped in by which ipset
|
||||
// to do that we use rule selector (which is just rule properties without IP's)
|
||||
for _, r := range rules {
|
||||
rulePair, err := d.protoRuleToFirewallRule(r)
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipset, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
ipset = &ipsetInfo{}
|
||||
}
|
||||
|
||||
ipset.ipCount++
|
||||
ipsetByRuleSelectors[selector] = ipset
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
if ipset.name == "" {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
}
|
||||
ipsetName := ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
break
|
||||
}
|
||||
newRulePairs[rulePair[0].GetRuleID()] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
if applyFailed {
|
||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
||||
@@ -140,55 +196,71 @@ func (d *DefaultManager) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
protocol := convertToFirewallProtocol(r.Protocol)
|
||||
if protocol == firewall.ProtocolUnknown {
|
||||
return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
}
|
||||
|
||||
action := convertFirewallAction(r.Action)
|
||||
if action == firewall.ActionUnknown {
|
||||
return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []int{value},
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
var err error
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
d.rulesPairs[rules[0].GetRuleID()] = rules
|
||||
return rules, nil
|
||||
d.rulesPairs[ruleID] = rules
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -198,7 +270,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -206,9 +279,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return append(rules, rule), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addOutRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -218,7 +299,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -226,6 +308,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return append(rules, rule), nil
|
||||
}
|
||||
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getRuleID(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) string {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||
//
|
||||
@@ -235,7 +334,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range networkMap.RemotePeers {
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
totalIPs++
|
||||
}
|
||||
@@ -246,6 +345,10 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
@@ -262,12 +365,22 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
}
|
||||
match := protocols[r.Protocol]
|
||||
|
||||
if _, ok := match[r.PeerIP]; ok {
|
||||
// special case, when we receive this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
match[r.PeerIP] = i
|
||||
|
||||
ipset := protocols[r.Protocol]
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
@@ -280,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for ther first element ALL, it must be done first
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
||||
mgmProto.FirewallRule_ALL,
|
||||
mgmProto.FirewallRule_ICMP,
|
||||
@@ -288,9 +401,6 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
mgmProto.FirewallRule_UDP,
|
||||
}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
@@ -346,6 +456,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||
switch protocol {
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package acl
|
||||
|
||||
@@ -6,7 +6,8 @@ import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
@@ -18,10 +19,10 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !android
|
||||
|
||||
package acl
|
||||
|
||||
import (
|
||||
@@ -7,30 +9,69 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance for the Linux
|
||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
if fm, err = uspfilter.Create(iface); err != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||
return nil, err
|
||||
var err error
|
||||
|
||||
checkResult := checkfw.Check()
|
||||
switch checkResult {
|
||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||
log.Debug("creating an iptables firewall manager for access control")
|
||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
||||
log.Infof("failed to create iptables manager for access control: %s", err)
|
||||
}
|
||||
} else {
|
||||
case checkfw.NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager for access control")
|
||||
if fm, err = nftables.Create(iface); err != nil {
|
||||
log.Debugf("failed to create nftables manager: %s", err)
|
||||
// fallback to iptables
|
||||
if fm, err = iptables.Create(iface); err != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("failed to create nftables manager for access control: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
var resetHookForUserspace func() error
|
||||
if fm != nil && err == nil {
|
||||
// err shadowing is used here, to ignore this error
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
resetHookForUserspace = fm.Reset
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
usfm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set kernel space firewall Reset as hook for userspace firewall
|
||||
// manager Reset method, to clean up
|
||||
if resetHookForUserspace != nil {
|
||||
usfm.SetResetHook(resetHookForUserspace)
|
||||
}
|
||||
|
||||
// to be consistent for any future extensions.
|
||||
// ignore this error
|
||||
if err := usfm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
fm = usfm
|
||||
}
|
||||
|
||||
if fm == nil || err != nil {
|
||||
log.Errorf("failed to create firewall manager: %s", err)
|
||||
// no firewall manager found or initialized correctly
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
acl, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
return
|
||||
@@ -55,6 +66,11 @@ func TestDefaultManager(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
||||
networkMap.FirewallRules = append(
|
||||
@@ -67,11 +83,6 @@ func TestDefaultManager(t *testing.T) {
|
||||
},
|
||||
)
|
||||
|
||||
existedRulesID := map[string]struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedRulesID[id] = struct{}{}
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
@@ -80,13 +91,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rules was removed
|
||||
for id := range existedRulesID {
|
||||
if _, ok := acl.rulesPairs[id]; ok {
|
||||
t.Errorf("old rule was not removed")
|
||||
return
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.rulesPairs {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
if previousCount != 1 {
|
||||
t.Errorf("old rule was not removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handle default rules", func(t *testing.T) {
|
||||
@@ -175,31 +189,33 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
}
|
||||
|
||||
r := rules[0]
|
||||
if r.PeerIP != "0.0.0.0" {
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
} else if r.Direction != mgmProto.FirewallRule_IN {
|
||||
case r.Direction != mgmProto.FirewallRule_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
|
||||
r = rules[1]
|
||||
if r.PeerIP != "0.0.0.0" {
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
} else if r.Direction != mgmProto.FirewallRule_OUT {
|
||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -267,7 +283,7 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
|
||||
manager := &DefaultManager{}
|
||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
||||
t.Errorf("we should got same amount of rules as intput, got %v", len(rules))
|
||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,13 +324,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
acl, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
return
|
||||
|
||||
203
client/internal/auth/device_flow.go
Normal file
203
client/internal/auth/device_flow.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
const (
|
||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
type RequestDeviceCodePayload struct {
|
||||
Audience string `json:"audience"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
type TokenRequestPayload struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequestResponse used for parsing Hosted token's response
|
||||
type TokenRequestResponse struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
TokenInfo
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
return &DeviceAuthorizationFlow{
|
||||
providerConfig: config,
|
||||
HTTPClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
form.Add("audience", d.providerConfig.Audience)
|
||||
form.Add("scope", d.providerConfig.Scope)
|
||||
req, err := http.NewRequest("POST", d.providerConfig.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := d.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return AuthFlowInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
||||
}
|
||||
|
||||
deviceCode := AuthFlowInfo{}
|
||||
err = json.Unmarshal(body, &deviceCode)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
||||
}
|
||||
|
||||
// Fallback to the verification_uri if the IdP doesn't support verification_uri_complete
|
||||
if deviceCode.VerificationURIComplete == "" {
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
|
||||
req, err := http.NewRequest("POST", d.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := d.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
if tokenResponse.Error != "" {
|
||||
if tokenResponse.Error == "authorization_pending" {
|
||||
continue
|
||||
} else if tokenResponse.Error == "slow_down" {
|
||||
interval += (3 * time.Second)
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
UseIDToken: d.providerConfig.UseIDToken,
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), d.providerConfig.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
package internal
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -53,7 +53,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut DeviceAuthInfo
|
||||
expectedOut AuthFlowInfo
|
||||
expectedMSG string
|
||||
expectPayload string
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: expectPayload,
|
||||
}
|
||||
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
||||
testCase4Out := AuthFlowInfo{ExpiresIn: 10}
|
||||
testCase4 := test{
|
||||
name: "Got Device Code",
|
||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||
@@ -113,8 +113,8 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
providerConfig: ProviderConfig{
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
@@ -125,7 +125,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
||||
@@ -145,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
inputMaxReqs int
|
||||
inputCountResBody string
|
||||
inputTimeout time.Duration
|
||||
inputInfo DeviceAuthInfo
|
||||
inputInfo AuthFlowInfo
|
||||
inputAudience string
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
@@ -155,7 +155,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
expectPayload string
|
||||
}
|
||||
|
||||
defaultInfo := DeviceAuthInfo{
|
||||
defaultInfo := AuthFlowInfo{
|
||||
DeviceCode: "test",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
@@ -278,8 +278,8 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
providerConfig: ProviderConfig{
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
@@ -287,11 +287,12 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient}
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
||||
109
client/internal/auth/oauth.go
Normal file
109
client/internal/auth/oauth.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// OAuthFlow represents an interface for authorization using different OAuth 2.0 flows
|
||||
type OAuthFlow interface {
|
||||
RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error)
|
||||
WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error)
|
||||
GetClientID(ctx context.Context) string
|
||||
}
|
||||
|
||||
// HTTPClient http client interface for API calls
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||
type AuthFlowInfo struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UseIDToken bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||
func (t TokenInfo) GetTokenToUse() string {
|
||||
if t.UseIDToken {
|
||||
return t.IDToken
|
||||
}
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"Please proceed with setting up this device using setup keys " +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
case ok && s.Code() == codes.Unimplemented:
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||
default:
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
255
client/internal/auth/pkce_flow.go
Normal file
255
client/internal/auth/pkce_flow.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||
|
||||
const (
|
||||
queryState = "state"
|
||||
queryCode = "code"
|
||||
queryError = "error"
|
||||
queryErrorDesc = "error_description"
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
// find the first available redirect URL
|
||||
for _, redirectURL := range config.RedirectURLs {
|
||||
if !isRedirectURLPortUsed(redirectURL) {
|
||||
availableRedirectURL = redirectURL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if availableRedirectURL == "" {
|
||||
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
||||
}
|
||||
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthorizationEndpoint,
|
||||
TokenURL: config.TokenEndpoint,
|
||||
},
|
||||
RedirectURL: availableRedirectURL,
|
||||
Scopes: strings.Split(config.Scope, " "),
|
||||
}
|
||||
|
||||
return &PKCEAuthorizationFlow{
|
||||
providerConfig: config,
|
||||
oAuthConfig: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
||||
return p.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a authorization code login flow information.
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
state, err := randomBytesInHex(24)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||
}
|
||||
p.state = state
|
||||
|
||||
codeVerifier, err := randomBytesInHex(64)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
||||
}
|
||||
p.codeVerifier = codeVerifier
|
||||
|
||||
codeChallenge := createCodeChallenge(codeVerifier)
|
||||
authURL := p.oAuthConfig.AuthCodeURL(
|
||||
state,
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
)
|
||||
|
||||
return AuthFlowInfo{
|
||||
VerificationURIComplete: authURL,
|
||||
ExpiresIn: defaultPKCETimeoutSeconds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Errorf("failed to close the server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
return TokenInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
token, err := p.handleRequest(req)
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
renderPKCEFlowTmpl(w, nil)
|
||||
tokenChan <- token
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errChan <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
TokenType: token.TokenType,
|
||||
ExpiresIn: token.Expiry.Second(),
|
||||
UseIDToken: p.providerConfig.UseIDToken,
|
||||
}
|
||||
if idToken, ok := token.Extra("id_token").(string); ok {
|
||||
tokenInfo.IDToken = idToken
|
||||
}
|
||||
|
||||
// if a provider doesn't support an audience, use the Client ID for token verification
|
||||
audience := p.providerConfig.Audience
|
||||
if audience == "" {
|
||||
audience = p.providerConfig.ClientID
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func createCodeChallenge(codeVerifier string) string {
|
||||
sha2 := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||
}
|
||||
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse redirect URL: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("error while closing the connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := make(map[string]string)
|
||||
if authError != nil {
|
||||
data["Error"] = authError.Error()
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(w, data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
60
client/internal/auth/util.go
Normal file
60
client/internal/auth/util.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func randomBytesInHex(count int) (string, error) {
|
||||
buf := make([]byte, count)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate %d random bytes: %v", count, err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
return fmt.Errorf("token received is empty")
|
||||
}
|
||||
|
||||
encodedClaims := strings.Split(token, ".")[1]
|
||||
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims := Claims{}
|
||||
err = json.Unmarshal(claimsString, &claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audience == nil {
|
||||
return fmt.Errorf("required token field audience is absent")
|
||||
}
|
||||
|
||||
// Audience claim of JWT can be a string or an array of strings
|
||||
switch aud := claims.Audience.(type) {
|
||||
case string:
|
||||
if aud == audience {
|
||||
return nil
|
||||
}
|
||||
case []interface{}:
|
||||
for _, audItem := range aud {
|
||||
if audStr, ok := audItem.(string); ok && audStr == audience {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid JWT token audience field")
|
||||
}
|
||||
3
client/internal/checkfw/check.go
Normal file
3
client/internal/checkfw/check.go
Normal file
@@ -0,0 +1,3 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package checkfw
|
||||
56
client/internal/checkfw/check_linux.go
Normal file
56
client/internal/checkfw/check_linux.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build !android
|
||||
|
||||
package checkfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN FWType = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
||||
IPTABLESWITHV6
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func Check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err == nil {
|
||||
if isIptablesClientAvailable(ip) {
|
||||
ipSupport := IPTABLES
|
||||
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if ip6Err == nil {
|
||||
if isIptablesClientAvailable(ipv6) {
|
||||
ipSupport = IPTABLESWITHV6
|
||||
}
|
||||
}
|
||||
return ipSupport
|
||||
}
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
@@ -271,9 +273,9 @@ func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
if parsedMgmtURL.Port() == "" {
|
||||
switch parsedMgmtURL.Scheme {
|
||||
case "https":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
||||
parsedMgmtURL.Host += ":443"
|
||||
case "http":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
||||
parsedMgmtURL.Host += ":80"
|
||||
default:
|
||||
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
||||
}
|
||||
|
||||
@@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
managementURL := "https://test.management.url:33071"
|
||||
adminURL := "https://app.admin.url:443"
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
@@ -63,7 +60,22 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
|
||||
@@ -12,8 +12,9 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -21,10 +22,30 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -78,7 +99,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
@@ -151,14 +172,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// in case of non Android os these variables will be nil
|
||||
md := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
@@ -168,8 +182,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
|
||||
@@ -190,6 +202,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return nil
|
||||
}
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig ProviderConfig
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type ProviderConfig struct {
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
@@ -88,7 +88,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
@@ -105,7 +105,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isProviderConfigValid(config ProviderConfig) error {
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
|
||||
@@ -3,28 +3,25 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
||||
fileGeneratedResolvConfSearchBeginContent = "search "
|
||||
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||
)
|
||||
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
||||
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
|
||||
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
|
||||
|
||||
const (
|
||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||
fileMaxLineCharsLimit = 256
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
)
|
||||
|
||||
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
|
||||
fileMaxLineCharsLimit = 256
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
)
|
||||
|
||||
type fileConfigurator struct {
|
||||
originalPerms os.FileMode
|
||||
@@ -54,53 +51,39 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
}
|
||||
managerType, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch managerType {
|
||||
case fileManager, netbirdManager:
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
||||
}
|
||||
}
|
||||
default:
|
||||
// todo improve this and maybe restart DNS manager from scratch
|
||||
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
|
||||
}
|
||||
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||
continue
|
||||
}
|
||||
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||
continue
|
||||
}
|
||||
|
||||
searchDomains += " " + dConf.domain
|
||||
appendedDomains++
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
||||
if err != nil {
|
||||
err = f.restore()
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
||||
}
|
||||
}
|
||||
|
||||
searchDomainList := searchDomains(config)
|
||||
|
||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
|
||||
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.serverIP}, nameServers...),
|
||||
others)
|
||||
|
||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
||||
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||
if err != nil {
|
||||
restoreErr := f.restore()
|
||||
if restoreErr != nil {
|
||||
log.Errorf("attempt to restore default file failed with error: %s", err)
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
|
||||
|
||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,15 +115,138 @@ func (f *fileConfigurator) restore() error {
|
||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
|
||||
log.Debugf("creating managed file %s", fileName)
|
||||
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(content)
|
||||
err := os.WriteFile(fileName, buf.Bytes(), permissions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
|
||||
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
||||
|
||||
for _, cfgLine := range others {
|
||||
buf.WriteString(cfgLine)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
return nil
|
||||
|
||||
if len(searchDomains) > 0 {
|
||||
buf.WriteString("search ")
|
||||
buf.WriteString(strings.Join(searchDomains, " "))
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, ns := range nameServers {
|
||||
buf.WriteString("nameserver ")
|
||||
buf.WriteString(ns)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func searchDomains(config hostDNSConfig) []string {
|
||||
listOfDomains := make([]string, 0)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.domain)
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
|
||||
file, err := os.Open(resolvconfFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf(`could not read existing resolv.conf`)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
|
||||
for {
|
||||
lineBytes, isPrefix, readErr := reader.ReadLine()
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if isPrefix {
|
||||
err = fmt.Errorf(`resolv.conf line too long`)
|
||||
return
|
||||
}
|
||||
|
||||
line := strings.TrimSpace(string(lineBytes))
|
||||
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "domain") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
||||
line = strings.ReplaceAll(line, "rotate", "")
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) == 1 {
|
||||
continue
|
||||
}
|
||||
line = strings.Join(splitLines, " ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "search") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
searchDomains = splitLines[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) != 2 {
|
||||
continue
|
||||
}
|
||||
nameServers = append(nameServers, splitLines[1])
|
||||
continue
|
||||
}
|
||||
|
||||
others = append(others, line)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// merge search domains lists and cut off the list if it is too long
|
||||
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
||||
lineSize := len("search")
|
||||
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
||||
|
||||
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
|
||||
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
|
||||
|
||||
return searchDomainsList
|
||||
}
|
||||
|
||||
// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
|
||||
// extend s slice with vs elements
|
||||
// return with the number of characters in the searchDomains line
|
||||
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
||||
for _, sd := range vs {
|
||||
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
||||
if tmpCharsNumber > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
||||
continue
|
||||
}
|
||||
|
||||
initialLineChars = tmpCharsNumber
|
||||
|
||||
if len(*s) >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
||||
continue
|
||||
}
|
||||
*s = append(*s, sd)
|
||||
}
|
||||
return initialLineChars
|
||||
}
|
||||
|
||||
func copyFile(src, dest string) error {
|
||||
|
||||
62
client/internal/dns/file_linux_test.go
Normal file
62
client/internal/dns/file_linux_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_mergeSearchDomains(t *testing.T) {
|
||||
searchDomains := []string{"a", "b"}
|
||||
originDomains := []string{"a", "b"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 4 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooMuchDomains(t *testing.T) {
|
||||
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
|
||||
originDomains := []string{"h", "i"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 6 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
|
||||
searchDomains := []string{"a", "b"}
|
||||
originDomains := []string{"c", "d", "e", "f", "g"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 6 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooLongDomain(t *testing.T) {
|
||||
searchDomains := []string{getLongLine()}
|
||||
originDomains := []string{"b"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 1 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
||||
}
|
||||
|
||||
searchDomains = []string{"b"}
|
||||
originDomains = []string{getLongLine()}
|
||||
|
||||
mergedDomains = mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 1 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func getLongLine() string {
|
||||
x := "search "
|
||||
for {
|
||||
for i := 0; i <= 9; i++ {
|
||||
if len(x) > fileMaxLineCharsLimit {
|
||||
return x
|
||||
}
|
||||
x = fmt.Sprintf("%s%d", x, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.domains = append(config.domains, domainConfig{
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
matchOnly: true,
|
||||
matchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(_ WGIface) (hostManager, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
@@ -184,12 +182,11 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey := s.getPrimaryService()
|
||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
||||
if primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key")
|
||||
}
|
||||
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,27 +195,32 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() string {
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
log.Error("got error while sending the command: ", err)
|
||||
return ""
|
||||
return "", ""
|
||||
}
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
primaryService := ""
|
||||
router := ""
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if strings.Contains(text, "PrimaryService") {
|
||||
return strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
if strings.Contains(text, "Router") {
|
||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return primaryService, router
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
|
||||
@@ -5,10 +5,10 @@ package dns
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,16 +22,14 @@ const (
|
||||
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
existingSearchDomains []string
|
||||
guid string
|
||||
routingAll bool
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,30 +146,11 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return r.updateSearchDomains([]string{})
|
||||
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
|
||||
}
|
||||
|
||||
valueList := strings.Split(value, ",")
|
||||
setExisting := false
|
||||
if len(r.existingSearchDomains) == 0 {
|
||||
r.existingSearchDomains = valueList
|
||||
setExisting = true
|
||||
}
|
||||
|
||||
if len(domains) == 0 && setExisting {
|
||||
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
newList := append(r.existingSearchDomains, domains...)
|
||||
|
||||
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding search domain failed with error: %s", err)
|
||||
}
|
||||
@@ -237,33 +214,3 @@ func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
|
||||
val, _, err := regKey.GetStringValue(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
|
||||
err = regKey.SetStringValue(key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -31,6 +32,11 @@ func (m *MockServer) DnsIP() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
if m.UpdateDNSServerFunc != nil {
|
||||
@@ -38,3 +44,7 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
}
|
||||
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
@@ -7,15 +7,13 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -124,7 +122,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...)
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
priority := networkManagerDbusSearchDomainOnlyPriority
|
||||
switch {
|
||||
@@ -291,12 +289,7 @@ func isNetworkManagerSupportedVersion() bool {
|
||||
}
|
||||
|
||||
func parseVersion(inputVersion string) (*version.Version, error) {
|
||||
reg, err := regexp.Compile(version.SemverRegexpRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if inputVersion == "" || !reg.MatchString(inputVersion) {
|
||||
if inputVersion == "" || !nbversion.SemverRegexp.MatchString(inputVersion) {
|
||||
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
|
||||
}
|
||||
|
||||
|
||||
57
client/internal/dns/notifier.go
Normal file
57
client/internal/dns/notifier.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
searchDomains []string
|
||||
}
|
||||
|
||||
func newNotifier(initialSearchDomains []string) *notifier {
|
||||
sort.Strings(initialSearchDomains)
|
||||
return ¬ifier{
|
||||
searchDomains: initialSearchDomains,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) onNewSearchDomains(searchDomains []string) {
|
||||
sort.Strings(searchDomains)
|
||||
|
||||
if len(n.searchDomains) != len(searchDomains) {
|
||||
n.searchDomains = searchDomains
|
||||
n.notify()
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(n.searchDomains, searchDomains) {
|
||||
return
|
||||
}
|
||||
|
||||
n.searchDomains = searchDomains
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged()
|
||||
}(n.listener)
|
||||
}
|
||||
@@ -3,24 +3,35 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
|
||||
type resolvconf struct {
|
||||
ifaceName string
|
||||
|
||||
originalSearchDomains []string
|
||||
originalNameServers []string
|
||||
othersConfigs []string
|
||||
}
|
||||
|
||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
// supported "openresolv" only
|
||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.Name(),
|
||||
ifaceName: wgInterface.Name(),
|
||||
originalSearchDomains: originalSearchDomains,
|
||||
originalNameServers: nameServers,
|
||||
othersConfigs: others,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -38,37 +49,20 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||
}
|
||||
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
searchDomainList := searchDomains(config)
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
||||
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||
continue
|
||||
}
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.serverIP}, r.originalNameServers...),
|
||||
r.othersConfigs)
|
||||
|
||||
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||
continue
|
||||
}
|
||||
|
||||
searchDomains += " " + dConf.domain
|
||||
appendedDomains++
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
|
||||
|
||||
err = r.applyConfig(content)
|
||||
err = r.applyConfig(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
|
||||
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -81,12 +75,12 @@ func (r *resolvconf) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content string) error {
|
||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
cmd.Stdin = strings.NewReader(content)
|
||||
cmd.Stdin = &content
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,29 +3,21 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
}
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
@@ -33,6 +25,8 @@ type Server interface {
|
||||
Stop()
|
||||
DnsIP() string
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@@ -42,21 +36,22 @@ type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
fakeResolverWG sync.WaitGroup
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
enabled bool
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
hostsDnsList []string
|
||||
hostsDnsListLock sync.Mutex
|
||||
|
||||
// make sense on mobile only
|
||||
searchDomainNotifier *notifier
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
@@ -70,9 +65,7 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -82,34 +75,47 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
||||
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
||||
ds.searchDomainNotifier.setListener(listener)
|
||||
setServerDns(ds)
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
dnsMux: mux,
|
||||
service: dnsService,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
customAddress: addrPort,
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
|
||||
if initialDnsCfg != nil {
|
||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||
}
|
||||
|
||||
defaultServer.evalRuntimeAddress()
|
||||
return defaultServer, nil
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -118,74 +124,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.permanent {
|
||||
err = s.service.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
}
|
||||
|
||||
// listen runs the listener in a go routine
|
||||
func (s *DefaultServer) listen() {
|
||||
// nil check required in unit tests
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.fakeResolverWG.Add(1)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
hookID := s.filterDNSTraffic()
|
||||
s.fakeResolverWG.Wait()
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
//
|
||||
// When kernel space interface used it return real DNS server listener IP address
|
||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||
func (s *DefaultServer) DnsIP() string {
|
||||
if !s.enabled {
|
||||
return ""
|
||||
}
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
@@ -194,34 +149,30 @@ func (s *DefaultServer) Stop() {
|
||||
defer s.mux.Unlock()
|
||||
s.ctxCancel()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if s.hostManager != nil {
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
s.fakeResolverWG.Done()
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDnsListLock.Lock()
|
||||
defer s.hostsDnsListLock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
s.hostsDnsList = hostsDnsList
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
return
|
||||
}
|
||||
return nil
|
||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||
s.addHostRootZone()
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
@@ -269,19 +220,28 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
var searchDomains []string
|
||||
|
||||
for _, dConf := range s.currentConfig.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
}
|
||||
return searchDomains
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.fakeResolverWG.Done()
|
||||
} else {
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.listen()
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
@@ -292,17 +252,16 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
}
|
||||
|
||||
@@ -310,6 +269,10 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -352,7 +315,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
@@ -370,7 +333,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
@@ -405,19 +368,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
s.service.RegisterMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
|
||||
if update.domain == nbdns.RootZone {
|
||||
isContainRootUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
existingHandler.stop()
|
||||
s.deregisterMux(key)
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
s.hostsDnsListLock.Lock()
|
||||
s.addHostRootZone()
|
||||
s.hostsDnsListLock.Unlock()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
s.service.DeregisterMux(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,14 +424,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
@@ -483,7 +451,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
s.service.DeregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
@@ -500,7 +468,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
s.service.RegisterMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@@ -516,93 +484,24 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) filterDNSTraffic() string {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
log.Error("can't set DNS filter, filter not initialized")
|
||||
return ""
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
a, err := netip.ParseAddr(ua)
|
||||
if err != nil {
|
||||
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
|
||||
continue
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
ipString := ua
|
||||
if !a.Is4() {
|
||||
ipString = fmt.Sprintf("[%s]", ua)
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evalRuntimeAddress() {
|
||||
defer func() {
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}()
|
||||
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = defaultPort
|
||||
return
|
||||
}
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
return
|
||||
}
|
||||
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
|
||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
||||
for _, c := range cfg.NameServerGroups {
|
||||
if c.Primary {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
||||
}
|
||||
handler.deactivate = func() {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.Mutex
|
||||
server Server
|
||||
)
|
||||
|
||||
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||
func GetServerDns() (Server, error) {
|
||||
mutex.Lock()
|
||||
if server == nil {
|
||||
mutex.Unlock()
|
||||
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||
}
|
||||
s := server
|
||||
mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func setServerDns(newServerServer Server) {
|
||||
mutex.Lock()
|
||||
server = newServerServer
|
||||
defer mutex.Unlock()
|
||||
}
|
||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetServerDns(t *testing.T) {
|
||||
_, err := GetServerDns()
|
||||
if err == nil {
|
||||
t.Errorf("invalid dns server instance")
|
||||
}
|
||||
|
||||
srv := &MockServer{}
|
||||
setServerDns(srv)
|
||||
|
||||
srvB, err := GetServerDns()
|
||||
if err != nil {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
if srvB != srv {
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
@@ -5,17 +5,59 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
type mocWGIface struct {
|
||||
filter iface.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() iface.WGAddress {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
w.filter = filter
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -26,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
@@ -221,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -239,9 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
// pretend we are running
|
||||
dnsServer.listenerIsRunning = true
|
||||
dnsServer.fakeResolverWG.Add(1)
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
@@ -276,6 +320,133 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||
if err != nil {
|
||||
t.Errorf("parse CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -292,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
||||
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
dnsServer.listen()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if !dnsServer.listenerIsRunning {
|
||||
t.Fatal("dns server listener is not running")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
err = dnsServer.service.Listen()
|
||||
if err != nil {
|
||||
t.Fatalf("dns server is not running: %s", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
@@ -314,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
@@ -349,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
dnsMux: dns.DefaultServeMux,
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@@ -412,62 +585,239 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||
mux := dns.NewServeMux()
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
var parsedAddrPort *netip.AddrPort
|
||||
if addrPort != "" {
|
||||
parsed, err := netip.ParseAddrPort(addrPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedAddrPort = &parsed
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
ds := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Enabled: true,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
customAddress: parsedAddrPort,
|
||||
}
|
||||
ds.evalRuntimeAddress()
|
||||
return ds
|
||||
}
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
|
||||
update2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(2, update2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
t.Helper()
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("create and init wireguard interface: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.SetFilter(pf)
|
||||
if err != nil {
|
||||
t.Fatalf("set packet filter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
|
||||
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 3,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||
return d.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
)
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
RuntimeIP() string
|
||||
}
|
||||
193
client/internal/dns/service_listener.go
Normal file
193
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
|
||||
type serviceViaListener struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
listenIP string
|
||||
listenPort int
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
||||
if err != nil {
|
||||
log.Errorf("failed to eval runtime address: %s", err)
|
||||
return err
|
||||
}
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
|
||||
if s.shouldApplyPortFwd() {
|
||||
s.ebpfService = ebpf.GetEbpfManagerInstance()
|
||||
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
|
||||
s.ebpfService = nil
|
||||
}
|
||||
}
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimePort() int {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.ebpfService != nil {
|
||||
return defaultPort
|
||||
} else {
|
||||
return s.listenPort
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimeIP() string {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||
}
|
||||
|
||||
return s.getFirstListenerAvailable()
|
||||
}
|
||||
|
||||
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
|
||||
// resolution won't work.
|
||||
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
|
||||
// traffic on port 53 and forward it to a local DNS server running on 5053.
|
||||
func (s *serviceViaListener) shouldApplyPortFwd() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.customAddr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.listenPort == defaultPort {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type serviceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -53,7 +53,7 @@ func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) stop() {
|
||||
log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
}
|
||||
|
||||
|
||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
}
|
||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
129
client/internal/ebpf/ebpf/bpf_bpfeb.go
Normal file
129
client/internal/ebpf/ebpf/bpf_bpfeb.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
|
||||
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.NbFeatures,
|
||||
m.NbMapDnsIp,
|
||||
m.NbMapDnsPort,
|
||||
m.NbWgProxySettingsMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.NbXdpProg,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfeb.o
|
||||
var _BpfBytes []byte
|
||||
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
Binary file not shown.
129
client/internal/ebpf/ebpf/bpf_bpfel.go
Normal file
129
client/internal/ebpf/ebpf/bpf_bpfel.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// Code generated by bpf2go; DO NOT EDIT.
|
||||
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
|
||||
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cilium/ebpf"
|
||||
)
|
||||
|
||||
// loadBpf returns the embedded CollectionSpec for bpf.
|
||||
func loadBpf() (*ebpf.CollectionSpec, error) {
|
||||
reader := bytes.NewReader(_BpfBytes)
|
||||
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't load bpf: %w", err)
|
||||
}
|
||||
|
||||
return spec, err
|
||||
}
|
||||
|
||||
// loadBpfObjects loads bpf and converts it into a struct.
|
||||
//
|
||||
// The following types are suitable as obj argument:
|
||||
//
|
||||
// *bpfObjects
|
||||
// *bpfPrograms
|
||||
// *bpfMaps
|
||||
//
|
||||
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
|
||||
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
|
||||
spec, err := loadBpf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return spec.LoadAndAssign(obj, opts)
|
||||
}
|
||||
|
||||
// bpfSpecs contains maps and programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfSpecs struct {
|
||||
bpfProgramSpecs
|
||||
bpfMapSpecs
|
||||
}
|
||||
|
||||
// bpfSpecs contains programs before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
// bpfObjects contains all objects after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfObjects struct {
|
||||
bpfPrograms
|
||||
bpfMaps
|
||||
}
|
||||
|
||||
func (o *bpfObjects) Close() error {
|
||||
return _BpfClose(
|
||||
&o.bpfPrograms,
|
||||
&o.bpfMaps,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfMaps contains all maps after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.NbFeatures,
|
||||
m.NbMapDnsIp,
|
||||
m.NbMapDnsPort,
|
||||
m.NbWgProxySettingsMap,
|
||||
)
|
||||
}
|
||||
|
||||
// bpfPrograms contains all programs after they have been loaded into the kernel.
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.NbXdpProg,
|
||||
)
|
||||
}
|
||||
|
||||
func _BpfClose(closers ...io.Closer) error {
|
||||
for _, closer := range closers {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not access this directly.
|
||||
//
|
||||
//go:embed bpf_bpfel.o
|
||||
var _BpfBytes []byte
|
||||
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
Binary file not shown.
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyDNSIP uint32 = 0
|
||||
mapKeyDNSPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
|
||||
err := tf.loadXdp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.setFeatureFlag(featureFlagDnsForwarder)
|
||||
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) FreeDNSFwd() error {
|
||||
log.Debugf("free ebpf DNS forwarder")
|
||||
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
}
|
||||
|
||||
func ip2int(ipString string) uint32 {
|
||||
ip := net.ParseIP(ipString)
|
||||
return binary.BigEndian.Uint32(ip.To4())
|
||||
}
|
||||
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyFeatures uint32 = 0
|
||||
|
||||
featureFlagWGProxy = 0b00000001
|
||||
featureFlagDnsForwarder = 0b00000010
|
||||
)
|
||||
|
||||
var (
|
||||
singleton manager.Manager
|
||||
singletonLock = &sync.Mutex{}
|
||||
)
|
||||
|
||||
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
|
||||
|
||||
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
|
||||
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
|
||||
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
|
||||
//
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
|
||||
type GeneralManager struct {
|
||||
lock sync.Mutex
|
||||
link link.Link
|
||||
featureFlags uint16
|
||||
bpfObjs bpfObjects
|
||||
}
|
||||
|
||||
// GetEbpfManagerInstance return a static eBpf Manager instance
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
singletonLock.Lock()
|
||||
defer singletonLock.Unlock()
|
||||
if singleton != nil {
|
||||
return singleton
|
||||
}
|
||||
singleton = &GeneralManager{}
|
||||
return singleton
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
|
||||
tf.featureFlags |= feature
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) loadXdp() error {
|
||||
if tf.link != nil {
|
||||
return nil
|
||||
}
|
||||
// it required for Docker
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iFace, err := net.InterfaceByName("lo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load pre-compiled programs into the kernel.
|
||||
err = loadBpfObjects(&tf.bpfObjs, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.link, err = link.AttachXDP(link.XDPOptions{
|
||||
Program: tf.bpfObjs.NbXdpProg,
|
||||
Interface: iFace.Index,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
_ = tf.bpfObjs.Close()
|
||||
tf.link = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
tf.featureFlags &^= feature
|
||||
|
||||
if tf.link == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tf.featureFlags == 0 {
|
||||
return tf.close()
|
||||
}
|
||||
|
||||
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) close() error {
|
||||
log.Debugf("detach ebpf program ")
|
||||
err := tf.bpfObjs.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close eBpf objects: %s", err)
|
||||
}
|
||||
|
||||
err = tf.link.Close()
|
||||
tf.link = nil
|
||||
return err
|
||||
}
|
||||
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager_setFeatureFlag(t *testing.T) {
|
||||
mgr := GeneralManager{}
|
||||
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||
if mgr.featureFlags != 1 {
|
||||
t.Errorf("invalid feature state")
|
||||
}
|
||||
|
||||
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||
if mgr.featureFlags != 3 {
|
||||
t.Errorf("invalid feature state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_unsetFeatureFlag(t *testing.T) {
|
||||
mgr := GeneralManager{}
|
||||
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||
|
||||
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
if mgr.featureFlags != 2 {
|
||||
t.Errorf("invalid feature state, expected: %d, got: %d", 2, mgr.featureFlags)
|
||||
}
|
||||
|
||||
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
if mgr.featureFlags != 0 {
|
||||
t.Errorf("invalid feature state, expected: %d, got: %d", 0, mgr.featureFlags)
|
||||
}
|
||||
}
|
||||
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
@@ -0,0 +1,64 @@
|
||||
const __u32 map_key_dns_ip = 0;
|
||||
const __u32 map_key_dns_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u32),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_map_dns_port = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__be32 dns_ip = 0;
|
||||
__be16 dns_port = 0;
|
||||
|
||||
// 13568 is 53 in big endian
|
||||
__be16 GENERAL_DNS_PORT = 13568;
|
||||
|
||||
bool read_settings() {
|
||||
__u16 *port_value;
|
||||
__u32 *ip_value;
|
||||
|
||||
// read dns ip
|
||||
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
|
||||
if(!ip_value) {
|
||||
return false;
|
||||
}
|
||||
dns_ip = htonl(*ip_value);
|
||||
|
||||
// read dns port
|
||||
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
|
||||
if (!port_value) {
|
||||
return false;
|
||||
}
|
||||
dns_port = htons(*port_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
|
||||
if (dns_port == 0) {
|
||||
if(!read_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("dns port: %d", ntohs(dns_port));
|
||||
bpf_printk("dns ip: %d", ntohl(dns_ip));
|
||||
}
|
||||
|
||||
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
|
||||
udp->dest = dns_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source == dns_port && ip->saddr == dns_ip) {
|
||||
udp->source = GENERAL_DNS_PORT;
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
return XDP_PASS;
|
||||
}
|
||||
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
@@ -0,0 +1,66 @@
|
||||
#include <stdbool.h>
|
||||
#include <linux/if_ether.h> // ETH_P_IP
|
||||
#include <linux/udp.h>
|
||||
#include <linux/ip.h>
|
||||
#include <netinet/in.h>
|
||||
#include <linux/bpf.h>
|
||||
#include <bpf/bpf_helpers.h>
|
||||
#include "dns_fwd.c"
|
||||
#include "wg_proxy.c"
|
||||
|
||||
#define bpf_printk(fmt, ...) \
|
||||
({ \
|
||||
char ____fmt[] = fmt; \
|
||||
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
|
||||
})
|
||||
|
||||
const __u16 flag_feature_wg_proxy = 0b01;
|
||||
const __u16 flag_feature_dns_fwd = 0b10;
|
||||
|
||||
const __u32 map_key_features = 0;
|
||||
struct bpf_map_def SEC("maps") nb_features = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
SEC("xdp")
|
||||
int nb_xdp_prog(struct xdp_md *ctx) {
|
||||
__u16 *features;
|
||||
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
|
||||
if (!features) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
void *data = (void *)(long)ctx->data;
|
||||
void *data_end = (void *)(long)ctx->data_end;
|
||||
struct ethhdr *eth = data;
|
||||
struct iphdr *ip = (data + sizeof(struct ethhdr));
|
||||
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
|
||||
|
||||
// return early if not enough data
|
||||
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non IPv4 packages
|
||||
if (eth->h_proto != htons(ETH_P_IP)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non UPD packages
|
||||
if (ip->protocol != IPPROTO_UDP) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (*features & flag_feature_dns_fwd) {
|
||||
xdp_dns_fwd(ip, udp);
|
||||
}
|
||||
|
||||
if (*features & flag_feature_wg_proxy) {
|
||||
xdp_wg_proxy(ip, udp);
|
||||
}
|
||||
return XDP_PASS;
|
||||
}
|
||||
char _license[] SEC("license") = "GPL";
|
||||
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
@@ -0,0 +1,54 @@
|
||||
const __u32 map_key_proxy_port = 0;
|
||||
const __u32 map_key_wg_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__u16 proxy_port = 0;
|
||||
__u16 wg_port = 0;
|
||||
|
||||
bool read_port_settings() {
|
||||
__u16 *value;
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
|
||||
if (!value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
proxy_port = *value;
|
||||
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
|
||||
if (!value) {
|
||||
return false;
|
||||
}
|
||||
wg_port = htons(*value);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
|
||||
if (proxy_port == 0 || wg_port == 0) {
|
||||
if (!read_port_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
|
||||
}
|
||||
|
||||
// 2130706433 = 127.0.0.1
|
||||
if (ip->daddr != htonl(2130706433)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source != wg_port){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
__be16 new_src_port = udp->dest;
|
||||
__be16 new_dst_port = htons(proxy_port);
|
||||
udp->dest = new_dst_port;
|
||||
udp->source = new_src_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package ebpf
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
const (
|
||||
mapKeyProxyPort uint32 = 0
|
||||
mapKeyWgPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
|
||||
log.Debugf("load ebpf WG proxy")
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
|
||||
err := tf.loadXdp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.setFeatureFlag(featureFlagWGProxy)
|
||||
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) FreeWGProxy() error {
|
||||
log.Debugf("free ebpf WG proxy")
|
||||
return tf.unsetFeatureFlag(featureFlagWGProxy)
|
||||
}
|
||||
15
client/internal/ebpf/instantiater_linux.go
Normal file
15
client/internal/ebpf/instantiater_linux.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
|
||||
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
|
||||
// panic on older Android version.
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
return ebpf.GetEbpfManagerInstance()
|
||||
}
|
||||
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package ebpf
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
|
||||
// GetEbpfManagerInstance return error because ebpf is not supported on all os
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
panic("unsupported os")
|
||||
}
|
||||
9
client/internal/ebpf/manager/manager.go
Normal file
9
client/internal/ebpf/manager/manager.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package manager
|
||||
|
||||
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||
type Manager interface {
|
||||
LoadDNSFwd(ip string, dnsPort int) error
|
||||
FreeDNSFwd() error
|
||||
LoadWgProxy(proxyPort, wgPort int) error
|
||||
FreeWGProxy() error
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user