mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 07:04:17 -04:00
Compare commits
161 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be209e7841 | ||
|
|
7c22a4ba9b | ||
|
|
f86aa47933 | ||
|
|
aafa342786 | ||
|
|
042141db06 | ||
|
|
4a1aee1ae0 | ||
|
|
ba33572ec9 | ||
|
|
9d213e0b54 | ||
|
|
5dde044fa5 | ||
|
|
5a3d9e401f | ||
|
|
fde1a2196c | ||
|
|
0aeb87742a | ||
|
|
6d747b2f83 | ||
|
|
199bf73103 | ||
|
|
17f5abc653 | ||
|
|
aa935bdae3 | ||
|
|
452419c4c3 | ||
|
|
17b1099032 | ||
|
|
a4b9e93217 | ||
|
|
63d7957140 | ||
|
|
9a6814deff | ||
|
|
190698bcf2 | ||
|
|
468fa2940b | ||
|
|
79a0647a26 | ||
|
|
17ceb3bde8 | ||
|
|
5a8f1763a6 | ||
|
|
f64e73ca70 | ||
|
|
b085419ab8 | ||
|
|
d78b652ff7 | ||
|
|
7251150c1c | ||
|
|
b65c2f69b0 | ||
|
|
d8ce08d898 | ||
|
|
e1c50248d9 | ||
|
|
ce2d14c08e | ||
|
|
52fd9a575a | ||
|
|
9028c3c1f7 | ||
|
|
9357a587e9 | ||
|
|
a47c69c472 | ||
|
|
bbea4c3cc3 | ||
|
|
b7a6cbfaa5 | ||
|
|
e18bf565a2 | ||
|
|
51fa3c92c5 | ||
|
|
d65602f904 | ||
|
|
8d9e1fed5f | ||
|
|
e1eddd1cab | ||
|
|
0fbf72434e | ||
|
|
51f133fdc6 | ||
|
|
d5338c09dc | ||
|
|
8fd4166c53 | ||
|
|
9bc7b9e897 | ||
|
|
db3cba5e0f | ||
|
|
cb3408a10b | ||
|
|
0afd738509 | ||
|
|
cf87f1e702 | ||
|
|
e890fdae54 | ||
|
|
dd14db6478 | ||
|
|
88747e3e01 | ||
|
|
fb30931365 | ||
|
|
a7547b9990 | ||
|
|
62bacee8dc | ||
|
|
71cd2e3e03 | ||
|
|
bdf71ab7ff | ||
|
|
a2f2a6e21a | ||
|
|
f89332fcd2 | ||
|
|
8604add997 | ||
|
|
93cab49696 | ||
|
|
b6835d9467 | ||
|
|
846d486366 | ||
|
|
9c56f74235 | ||
|
|
25b3641be8 | ||
|
|
c41504b571 | ||
|
|
399493a954 | ||
|
|
4771fed64f | ||
|
|
88117f7d16 | ||
|
|
5ac9f9fe2f | ||
|
|
a7d6632298 | ||
|
|
d4194cba6a | ||
|
|
131d9f1bc7 | ||
|
|
f099e02b34 | ||
|
|
93646e6a13 | ||
|
|
67a2127fd7 | ||
|
|
dd7fcbd083 | ||
|
|
d5f330b9c0 | ||
|
|
9fa0fbda0d | ||
|
|
5a7aa461de | ||
|
|
e9c967b27c | ||
|
|
ace588758c | ||
|
|
8bb16e016c | ||
|
|
6a2a97f088 | ||
|
|
3591795a58 | ||
|
|
5311ce4e4a | ||
|
|
c61cb00f40 | ||
|
|
72a1e97304 | ||
|
|
5242851ecc | ||
|
|
cb69348a30 | ||
|
|
69dbcbd362 | ||
|
|
5de4acf2fe | ||
|
|
aa3b79d311 | ||
|
|
8b4ec96516 | ||
|
|
1f3a12d941 | ||
|
|
1de3bb5420 | ||
|
|
163933d429 | ||
|
|
875a2e2b63 | ||
|
|
fd8bba6aa3 | ||
|
|
86908eee58 | ||
|
|
c1caec3fcb | ||
|
|
b28b8fce50 | ||
|
|
f780f17f85 | ||
|
|
5903715a61 | ||
|
|
5469de53c5 | ||
|
|
bc3d647d6b | ||
|
|
7060b63838 | ||
|
|
3168b80ad0 | ||
|
|
818c6b885f | ||
|
|
01f28baec7 | ||
|
|
56896794b3 | ||
|
|
f73a2e2848 | ||
|
|
19fa071a93 | ||
|
|
cba3c549e9 | ||
|
|
65247de48d | ||
|
|
2d1dfa3ae7 | ||
|
|
5961c8330e | ||
|
|
d275d411aa | ||
|
|
5ecafef5d2 | ||
|
|
d073a250cc | ||
|
|
a1c48468ab | ||
|
|
dd1e730454 | ||
|
|
050f140245 | ||
|
|
006ba32086 | ||
|
|
b03343bc4d | ||
|
|
36d62f1844 | ||
|
|
08733ed8d5 | ||
|
|
27ed88f918 | ||
|
|
45fc89b2c9 | ||
|
|
f822a58326 | ||
|
|
d1f13025d1 | ||
|
|
3f8b500f0b | ||
|
|
0d2db4b172 | ||
|
|
7a18dea766 | ||
|
|
ae5f69562d | ||
|
|
755ffcfc73 | ||
|
|
dc8f55f23e | ||
|
|
89249b414f | ||
|
|
92adf57fea | ||
|
|
1cd5a66575 | ||
|
|
b9fc008542 | ||
|
|
d5bf79bc51 | ||
|
|
4bf574037f | ||
|
|
47c44d4b87 | ||
|
|
96f866fb68 | ||
|
|
141065f14e | ||
|
|
8e74fb1fa8 | ||
|
|
ba96e102b4 | ||
|
|
2129b23fe7 | ||
|
|
efd05ca023 | ||
|
|
c829ad930c | ||
|
|
ad1f18a52a | ||
|
|
bab420ca77 | ||
|
|
a729c83b06 | ||
|
|
a7e55cc5e3 | ||
|
|
b7c0eba1e5 |
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.20-bullseye
|
||||
FROM golang:1.21-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.20"
|
||||
"version": "1.21"
|
||||
}
|
||||
},
|
||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||
|
||||
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -2,15 +2,17 @@
|
||||
name: Bug/Issue report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
labels: ['triage-needed']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the problem**
|
||||
|
||||
A clear and concise description of what the problem is.
|
||||
|
||||
**To Reproduce**
|
||||
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
@@ -18,13 +20,25 @@ Steps to reproduce the behavior:
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Are you using NetBird Cloud?**
|
||||
|
||||
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
|
||||
|
||||
**NetBird version**
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -d output:**
|
||||
If applicable, add the output of the `netbird status -d` command
|
||||
|
||||
If applicable, add the `netbird status -d' command output.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -2,7 +2,7 @@
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
labels: ['feature-request']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@@ -35,5 +35,8 @@ jobs:
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
12
.github/workflows/golang-test-linux.yml
vendored
12
.github/workflows/golang-test-linux.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
@@ -41,8 +41,11 @@ jobs:
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
@@ -50,7 +53,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
@@ -69,6 +72,9 @@ jobs:
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
|
||||
7
.github/workflows/golang-test-windows.yml
vendored
7
.github/workflows/golang-test-windows.yml
vendored
@@ -23,13 +23,13 @@ jobs:
|
||||
uses: actions/setup-go@v4
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
@@ -44,9 +44,10 @@ jobs:
|
||||
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
- run: "[Environment]::SetEnvironmentVariable('NETBIRD_STORE_ENGINE', 'jsonfile', 'Machine')"
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -1,7 +1,7 @@
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Android build validation
|
||||
name: Mobile build validation
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -11,7 +11,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
android_build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -19,9 +19,16 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v3
|
||||
@@ -29,13 +36,30 @@ jobs:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/tools/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android nebtird lib
|
||||
- name: build android netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ios_build:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
38
.github/workflows/release.yml
vendored
38
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ on:
|
||||
- 'client/ui/**'
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.10"
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
@@ -44,15 +44,18 @@ jobs:
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
${{ runner.os }}-go-releaser-
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -117,14 +120,17 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-
|
||||
${{ runner.os }}-ui-go-releaser-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -169,18 +175,24 @@ jobs:
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-
|
||||
${{ runner.os }}-ui-go-releaser-darwin-
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Run GoReleaser
|
||||
id: goreleaser
|
||||
|
||||
22
.github/workflows/sync-main.yml
vendored
Normal file
22
.github/workflows/sync-main.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: sync main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger_sync_main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
23
.github/workflows/sync-tag.yml
vendored
Normal file
23
.github/workflows/sync-tag.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: sync tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger_sync_tag:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
45
.github/workflows/test-infrastructure-files.yml
vendored
45
.github/workflows/test-infrastructure-files.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
env:
|
||||
CI_NETBIRD_DOMAIN: localhost
|
||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||
@@ -87,8 +87,10 @@ jobs:
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
|
||||
run: |
|
||||
set -x
|
||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -107,7 +109,7 @@ jobs:
|
||||
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
|
||||
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
|
||||
grep UseIDToken management.json | grep false
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
@@ -120,10 +122,14 @@ jobs:
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Build management binary
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||
@@ -143,7 +149,7 @@ jobs:
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker-compose up -d
|
||||
sleep 5
|
||||
@@ -152,9 +158,16 @@ jobs:
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("infrastructure_files")) | .State' | grep -c running)
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
|
||||
- name: test geolocation databases
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -175,8 +188,24 @@ jobs:
|
||||
- name: test management.json file gen
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
run: test -f turnserver.conf
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
run: test -f dashboard.env
|
||||
run: test -f dashboard.env
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -6,11 +6,20 @@ bin/
|
||||
.env
|
||||
conf.json
|
||||
http-cmds.sh
|
||||
infrastructure_files/management.json
|
||||
infrastructure_files/management-*.json
|
||||
infrastructure_files/docker-compose.yml
|
||||
infrastructure_files/openid-configuration.json
|
||||
infrastructure_files/turnserver.conf
|
||||
setup.env
|
||||
infrastructure_files/**/Caddyfile
|
||||
infrastructure_files/**/dashboard.env
|
||||
infrastructure_files/**/zitadel.env
|
||||
infrastructure_files/**/management.json
|
||||
infrastructure_files/**/management-*.json
|
||||
infrastructure_files/**/docker-compose.yml
|
||||
infrastructure_files/**/openid-configuration.json
|
||||
infrastructure_files/**/turnserver.conf
|
||||
infrastructure_files/**/management.json.bkp.**
|
||||
infrastructure_files/**/management-*.json.bkp.**
|
||||
infrastructure_files/**/docker-compose.yml.bkp.**
|
||||
infrastructure_files/**/openid-configuration.json.bkp.**
|
||||
infrastructure_files/**/turnserver.conf.bkp.**
|
||||
management/management
|
||||
client/client
|
||||
client/client.exe
|
||||
@@ -20,4 +29,4 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
*.db
|
||||
GeoLite2-City*
|
||||
@@ -63,6 +63,14 @@ linters-settings:
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
@@ -93,6 +101,7 @@ linters:
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
|
||||
@@ -54,7 +54,7 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-default.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -71,7 +71,7 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-default.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
||||
@@ -19,6 +19,7 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
- [Local NetBird setup](#local-netbird-setup)
|
||||
- [Dev Container Support](#dev-container-support)
|
||||
- [Build and start](#build-and-start)
|
||||
- [Test suite](#test-suite)
|
||||
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
|
||||
@@ -135,6 +136,48 @@ checked out and set up:
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
### Dev Container Support
|
||||
|
||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
|
||||
|
||||
#### 1. Prerequisites:
|
||||
|
||||
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
|
||||
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
|
||||
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
|
||||
|
||||
#### 2. Clone the Repository:
|
||||
|
||||
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
|
||||
|
||||
#### 3. Open in project in IDE of your choice:
|
||||
|
||||
**VScode**:
|
||||
|
||||
Open the project folder in Visual Studio Code:
|
||||
|
||||
```bash
|
||||
code .
|
||||
```
|
||||
|
||||
When you open the project in VS Code, it will detect the presence of a dev container configuration.
|
||||
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
|
||||
|
||||
**Goland**:
|
||||
|
||||
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
|
||||
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
|
||||
|
||||
#### 4. Wait for the Container to Build:
|
||||
|
||||
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
|
||||
|
||||
#### 6. Development:
|
||||
|
||||
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
|
||||
|
||||
|
||||
### Build and start
|
||||
#### Client
|
||||
|
||||
@@ -146,6 +189,8 @@ CGO_ENABLED=0 go build .
|
||||
|
||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
|
||||
|
||||
> To test the client GUI application on Windows machines with RDP or vituralized environments (e.g. virtualbox or cloud), you need to download and extract the opengl32.dll from https://fdossena.com/?p=mesa/index.frag next to the built application.
|
||||
|
||||
To start NetBird the client in the foreground:
|
||||
|
||||
```
|
||||
@@ -229,6 +274,8 @@ go test -exec sudo ./...
|
||||
```
|
||||
> On Windows use a powershell with administrator privileges
|
||||
|
||||
> Non-GTK environments will need the `libayatana-appindicator3-dev` (debian/ubuntu) package installed
|
||||
|
||||
## Checklist before submitting a PR
|
||||
As a critical network service and open-source project, we must enforce a few things before submitting the pull-requests:
|
||||
- Keep functions as simple as possible, with a single purpose
|
||||
|
||||
37
README.md
37
README.md
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
@@ -42,25 +42,22 @@
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
### Open-Source Network Security in a Single Platform
|
||||
|
||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||

|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Automation | Platforms |
|
||||
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
||||
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
@@ -109,8 +106,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
@@ -126,4 +124,5 @@ We use open-source technologies like [WireGuard®](https://www.wireguard.com/),
|
||||
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
|
||||
dddd
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3
|
||||
FROM alpine:3.18.5
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||
|
||||
@@ -79,6 +79,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -109,6 +110,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -139,6 +141,11 @@ func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
}
|
||||
|
||||
// SetInfoLogLevel configure the logger to info level
|
||||
func (c *Client) SetInfoLogLevel() {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
}
|
||||
|
||||
// PeersList return with the list of the PeerInfos
|
||||
func (c *Client) PeersList() *PeerInfoArray {
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ var loginCmd = &cobra.Command{
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ var loginCmd = &cobra.Command{
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
@@ -82,12 +82,15 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
@@ -151,13 +154,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -25,8 +25,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,6 +56,13 @@ var (
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
@@ -87,14 +101,21 @@ func init() {
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
@@ -118,6 +139,10 @@ func init() {
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -168,7 +193,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
return prefix + upper
|
||||
}
|
||||
|
||||
// DialClientGRPCServer returns client connection to the dameno server.
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
@@ -2,8 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -24,12 +22,8 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
name := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "Netbird"
|
||||
}
|
||||
return &service.Config{
|
||||
Name: name,
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Option: make(service.KeyValue),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func (p *program) Start(svc service.Service) error {
|
||||
@@ -109,7 +110,6 @@ var runCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Printf("Netbird service is running")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ var installCmd = &cobra.Command{
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been installed")
|
||||
return nil
|
||||
},
|
||||
@@ -106,7 +107,7 @@ var uninstallCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println("Netbird has been uninstalled")
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -22,14 +22,20 @@ import (
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
@@ -41,11 +47,25 @@ type peersStateOutput struct {
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Available int `json:"available" yaml:"available"`
|
||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
@@ -53,26 +73,40 @@ type iceCandidateType struct {
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
prefixNamesFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -83,12 +117,14 @@ var statusCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
@@ -142,7 +178,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false)
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -172,8 +208,12 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
}
|
||||
@@ -185,11 +225,26 @@ func parseFilters() error {
|
||||
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
|
||||
}
|
||||
ipsFilterMap[addr] = struct{}{}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for _, name := range prefixNamesFilter {
|
||||
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
|
||||
}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableDetailFlagWhenFilterFlag() {
|
||||
if !detailFlag && !jsonFlag && !yamlFlag {
|
||||
detailFlag = true
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
@@ -197,37 +252,89 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
Error: managementState.Error,
|
||||
}
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
Error: signalState.Error,
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
var relayStateDetail []relayStateOutputDetail
|
||||
|
||||
var relaysAvailable int
|
||||
for _, relay := range relays {
|
||||
available := relay.GetAvailable()
|
||||
relayStateDetail = append(relayStateDetail,
|
||||
relayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
},
|
||||
)
|
||||
|
||||
if available {
|
||||
relaysAvailable++
|
||||
}
|
||||
}
|
||||
|
||||
return relayStateOutput{
|
||||
Total: len(relays),
|
||||
Available: relaysAvailable,
|
||||
Details: relayStateDetail,
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
connType := ""
|
||||
peersConnected := 0
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
for _, pbPeerState := range peers {
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
@@ -238,10 +345,15 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
@@ -256,7 +368,16 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetRoutes(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
@@ -306,22 +427,31 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
|
||||
managementConnString := "Disconnected"
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
signalConnString := "Disconnected"
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
@@ -333,6 +463,64 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(overview.Routes) > 0 {
|
||||
sort.Strings(overview.Routes)
|
||||
routes = strings.Join(overview.Routes, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
@@ -340,25 +528,33 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers)
|
||||
summary := parseGeneralSummary(overview, true)
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
@@ -369,7 +565,7 @@ func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
)
|
||||
}
|
||||
|
||||
func parsePeers(peers peersStateOutput) string {
|
||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
@@ -386,6 +582,48 @@ func parsePeers(peers peersStateOutput) string {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
localICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Local != "" {
|
||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||
}
|
||||
|
||||
remoteICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
lastStatusUpdate := "-"
|
||||
if !peerState.LastStatusUpdate.IsZero() {
|
||||
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
lastWireGuardHandshake := "-"
|
||||
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
|
||||
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
} else {
|
||||
if rosenpassPermissive {
|
||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||
} else {
|
||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||
}
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(peerState.Routes) > 0 {
|
||||
sort.Strings(peerState.Routes)
|
||||
routes = strings.Join(peerState.Routes, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
@@ -395,7 +633,12 @@ func parsePeers(peers peersStateOutput) string {
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n",
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
@@ -404,7 +647,14 @@ func parsePeers(peers peersStateOutput) string {
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
lastStatusUpdate,
|
||||
lastWireGuardHandshake,
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
)
|
||||
|
||||
peersString += peerString
|
||||
@@ -415,6 +665,7 @@ func parsePeers(peers peersStateOutput) string {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -431,5 +682,39 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
return statusEval || ipEval
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -25,41 +28,93 @@ var resp = &proto.StatusResponse{
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
RemoteIceCandidateEndpoint: "",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: []*proto.RelayState{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
@@ -82,6 +137,16 @@ var overview = statusOutputOverview{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
@@ -95,6 +160,13 @@ var overview = statusOutputOverview{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "10.0.0.1:10001",
|
||||
Remote: "10.0.10.1:10002",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -103,15 +175,58 @@ var overview = statusOutputOverview{
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: relayStateOutput{
|
||||
Total: 2,
|
||||
Available: 1,
|
||||
Details: []relayStateOutputDetail{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -145,107 +260,219 @@ func TestSortingOfPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
json, _ := parseToJSON(overview)
|
||||
jsonString, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSON := "{\"" +
|
||||
"peers\":" +
|
||||
"{" +
|
||||
"\"total\":2," +
|
||||
"\"connected\":2," +
|
||||
"\"details\":" +
|
||||
"[" +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-1.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.101\"," +
|
||||
"\"publicKey\":\"Pubkey1\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2001-01-01T01:01:01Z\"," +
|
||||
"\"connectionType\":\"P2P\"," +
|
||||
"\"direct\":true," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"\"," +
|
||||
"\"remote\":\"\"" +
|
||||
"}" +
|
||||
"}," +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-2.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.102\"," +
|
||||
"\"publicKey\":\"Pubkey2\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2002-02-02T02:02:02Z\"," +
|
||||
"\"connectionType\":\"Relayed\"," +
|
||||
"\"direct\":false," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"relay\"," +
|
||||
"\"remote\":\"prflx\"" +
|
||||
"}" +
|
||||
"}" +
|
||||
"]" +
|
||||
"}," +
|
||||
"\"cliVersion\":\"development\"," +
|
||||
"\"daemonVersion\":\"0.14.1\"," +
|
||||
"\"management\":" +
|
||||
"{" +
|
||||
"\"url\":\"my-awesome-management.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"signal\":" +
|
||||
"{\"" +
|
||||
"url\":\"my-awesome-signal.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"netbirdIp\":\"192.168.178.100/16\"," +
|
||||
"\"publicKey\":\"Some-Pub-Key\"," +
|
||||
"\"usesKernelInterface\":true," +
|
||||
"\"fqdn\":\"some-localhost.awesome-domain.com\"" +
|
||||
"}"
|
||||
expectedJSONString := `
|
||||
{
|
||||
"peers": {
|
||||
"total": 2,
|
||||
"connected": 2,
|
||||
"details": [
|
||||
{
|
||||
"fqdn": "peer-1.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.101",
|
||||
"publicKey": "Pubkey1",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"direct": true,
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.102",
|
||||
"publicKey": "Pubkey2",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"direct": false,
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"quantumResistance": false,
|
||||
"routes": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"cliVersion": "development",
|
||||
"daemonVersion": "0.14.1",
|
||||
"management": {
|
||||
"url": "my-awesome-management.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"signal": {
|
||||
"url": "my-awesome-signal.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"relays": {
|
||||
"total": 2,
|
||||
"available": 1,
|
||||
"details": [
|
||||
{
|
||||
"uri": "stun:my-awesome-stun.com:3478",
|
||||
"available": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
"available": false,
|
||||
"error": "context: deadline exceeded"
|
||||
}
|
||||
]
|
||||
},
|
||||
"netbirdIp": "192.168.178.100/16",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
assert.Equal(t, expectedJSON, json)
|
||||
var expectedJSON bytes.Buffer
|
||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||
|
||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML := "peers:\n" +
|
||||
" total: 2\n" +
|
||||
" connected: 2\n" +
|
||||
" details:\n" +
|
||||
" - fqdn: peer-1.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.101\n" +
|
||||
" publicKey: Pubkey1\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2001-01-01T01:01:01Z\n" +
|
||||
" connectionType: P2P\n" +
|
||||
" direct: true\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: \"\"\n" +
|
||||
" remote: \"\"\n" +
|
||||
" - fqdn: peer-2.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.102\n" +
|
||||
" publicKey: Pubkey2\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2002-02-02T02:02:02Z\n" +
|
||||
" connectionType: Relayed\n" +
|
||||
" direct: false\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: relay\n" +
|
||||
" remote: prflx\n" +
|
||||
"cliVersion: development\n" +
|
||||
"daemonVersion: 0.14.1\n" +
|
||||
"management:\n" +
|
||||
" url: my-awesome-management.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"signal:\n" +
|
||||
" url: my-awesome-signal.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"netbirdIp: 192.168.178.100/16\n" +
|
||||
"publicKey: Some-Pub-Key\n" +
|
||||
"usesKernelInterface: true\n" +
|
||||
"fqdn: some-localhost.awesome-domain.com\n"
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
total: 2
|
||||
connected: 2
|
||||
details:
|
||||
- fqdn: peer-1.awesome-domain.com
|
||||
netbirdIp: 192.168.178.101
|
||||
publicKey: Pubkey1
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
direct: true
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
direct: false
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
url: my-awesome-management.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
signal:
|
||||
url: my-awesome-signal.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
relays:
|
||||
total: 2
|
||||
available: 1
|
||||
details:
|
||||
- uri: stun:my-awesome-stun.com:3478
|
||||
available: true
|
||||
error: ""
|
||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||
available: false
|
||||
error: 'context: deadline exceeded'
|
||||
netbirdIp: 192.168.178.100/16
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
@@ -253,50 +480,76 @@ func TestParsingToYAML(t *testing.T) {
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := "Peers detail:\n" +
|
||||
" peer-1.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.101\n" +
|
||||
" Public key: Pubkey1\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: P2P\n" +
|
||||
" Direct: true\n" +
|
||||
" ICE candidate (Local/Remote): -/-\n" +
|
||||
" Last connection update: 2001-01-01 01:01:01\n" +
|
||||
"\n" +
|
||||
" peer-2.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.102\n" +
|
||||
" Public key: Pubkey2\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: Relayed\n" +
|
||||
" Direct: false\n" +
|
||||
" ICE candidate (Local/Remote): relay/prflx\n" +
|
||||
" Last connection update: 2002-02-02 02:02:02\n" +
|
||||
"\n" +
|
||||
"Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected to my-awesome-management.com:443\n" +
|
||||
"Signal: Connected to my-awesome-signal.com:443\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
expectedDetail :=
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
Public key: Pubkey1
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Last connection update: 2001-01-01 01:01:01
|
||||
Last WireGuard handshake: 2001-01-01 01:01:02
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
Public key: Pubkey2
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Last connection update: 2002-02-02 02:02:02
|
||||
Last WireGuard handshake: 2002-02-02 02:02:03
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false)
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString := "Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected\n" +
|
||||
"Signal: Connected\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
expectedString :=
|
||||
`Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
@@ -78,8 +78,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore, false)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -36,6 +38,8 @@ var (
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -85,16 +89,53 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
@@ -142,7 +183,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
@@ -152,6 +192,38 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
loginRequest.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
loginRequest.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
wp := int64(wireguardPort)
|
||||
loginRequest.WireguardPort = &wp
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
@@ -223,6 +295,18 @@ func validateNATExternalIPs(list []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseInterfaceName(name string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "utun") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid interface name %s. Please use the prefix utun followed by a number on MacOS. e.g., utun1 or utun199", name)
|
||||
}
|
||||
|
||||
func validateElement(element string) (int, error) {
|
||||
if isValidIP(element) {
|
||||
return ipInputType, nil
|
||||
|
||||
32
client/firewall/create.go
Normal file
32
client/firewall/create.go
Normal file
@@ -0,0 +1,32 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = fm.AllowNetbird()
|
||||
if err != nil {
|
||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
107
client/firewall/create_linux.go
Normal file
107
client/firewall/create_linux.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build !android
|
||||
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN FWType = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var errFw error
|
||||
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Debug("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
case NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
var errUsp error
|
||||
if errFw == nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
if errUsp != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||
return nil, errUsp
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
if errFw != nil {
|
||||
return nil, errFw
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
11
client/firewall/iface.go
Normal file
11
client/firewall/iface.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
473
client/firewall/iptables/acl_linux.go
Normal file
473
client/firewall/iptables/acl_linux.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
|
||||
postRoutingMark = "0x000007e4"
|
||||
)
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
||||
}
|
||||
|
||||
m.seedInitialEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = m.createDefaultChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||
if err != nil {
|
||||
return []firewall.Rule{rule}, err
|
||||
}
|
||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.chain == "PREROUTING" {
|
||||
goto DELETERULE
|
||||
}
|
||||
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ipsetList.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
DELETERULE:
|
||||
var table string
|
||||
if r.chain == "PREROUTING" {
|
||||
table = "mangle"
|
||||
} else {
|
||||
table = "filter"
|
||||
}
|
||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
}
|
||||
|
||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||
var src []string
|
||||
if ipsetName != "" {
|
||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||
} else {
|
||||
src = []string{"-s", ip.String()}
|
||||
}
|
||||
specs := []string{
|
||||
"-d", m.wgIface.Address().IP.String(),
|
||||
"-p", protocol,
|
||||
"--dport", port,
|
||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||
}
|
||||
|
||||
specs = append(src, specs...)
|
||||
|
||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: "PREROUTING",
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["INPUT"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries["FORWARD"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) createDefaultChains() error {
|
||||
// chain netbird-acl-input-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if chainName == "FORWARD" {
|
||||
// position 2 because we add it after router's, jump rule
|
||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
|
||||
m.appendToEntries("PREROUTING",
|
||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
return append(specs, "-j", actionToStr(action))
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,43 +1,27 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
|
||||
ChainInputFilterName = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
|
||||
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
// dropAllDefaultRule in the Netbird chain
|
||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
|
||||
inputDefaultRuleSpecs []string
|
||||
outputDefaultRuleSpecs []string
|
||||
wgIface iFaceMapper
|
||||
|
||||
rulesets map[string]ruleset
|
||||
aclMgr *aclManager
|
||||
router *routerManager
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -47,47 +31,29 @@ type iFaceMapper interface {
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
type ruleset struct {
|
||||
rule *Rule
|
||||
ips map[string]string
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
inputDefaultRuleSpecs: []string{
|
||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||
outputDefaultRuleSpecs: []string{
|
||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||
rulesets: make(map[string]ruleset),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
|
||||
if ipv6Supported {
|
||||
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
||||
}
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
if m.ipv4Client == nil && m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
||||
m.router, err = newRouterManager(context, iptablesClient)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -96,159 +62,44 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
client, err := m.client(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
|
||||
if ipsetName != "" {
|
||||
rs, rsExists := m.rulesets[ipsetName]
|
||||
if !rsExists {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
if rsExists {
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
rs.ips[ip.String()] = ruleID
|
||||
return &Rule{
|
||||
ruleID: ruleID,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
}
|
||||
// this is new ipset so we need to create firewall rule for it
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is output rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is input rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: ruleID,
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
if ipsetName != "" {
|
||||
// ipset name is defined and it means that this rule was created
|
||||
// for it, need to associate it with ruleset
|
||||
m.rulesets[ipsetName] = ruleset{
|
||||
rule: rule,
|
||||
ips: map[string]string{rule.ip: ruleID},
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
return m.aclMgr.DeleteRule(rule)
|
||||
}
|
||||
|
||||
client := m.ipv4Client
|
||||
if r.v6 {
|
||||
if m.ipv6Client == nil {
|
||||
return fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
client = m.ipv6Client
|
||||
}
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := rs.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(rs.ips, r.ip)
|
||||
}
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(rs.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
delete(m.rulesets, r.ipsetName)
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
r = rs.rule
|
||||
}
|
||||
|
||||
if r.dst {
|
||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||
}
|
||||
return client.Delete("filter", ChainInputFilterName, r.specs...)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -256,223 +107,49 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.reset(m.ipv4Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
|
||||
errAcl := m.aclMgr.Reset()
|
||||
if errAcl != nil {
|
||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
||||
}
|
||||
if m.ipv6Client != nil {
|
||||
if err := m.reset(m.ipv6Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
|
||||
}
|
||||
errMgr := m.router.Reset()
|
||||
if errMgr != nil {
|
||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
||||
return errMgr
|
||||
}
|
||||
|
||||
return nil
|
||||
return errAcl
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if m.wgIface.IsUserspaceBind() {
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionOUT,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
return err
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionOUT,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if input chain exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists(table, ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if output chain exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for ipsetName := range m.rulesets {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
delete(m.rulesets, ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case fw.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
return append(specs, "-j", m.actionToStr(action))
|
||||
}
|
||||
|
||||
// rawClient returns corresponding iptables client for the given ip
|
||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
||||
if ip.To4() != nil {
|
||||
return m.ipv4Client, nil
|
||||
}
|
||||
if m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
return m.ipv6Client, nil
|
||||
}
|
||||
|
||||
// client returns client with initialized chain and default rules
|
||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
client, err := m.rawClient(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := client.ChainExists("filter", ChainInputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists("filter", ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *Manager) actionToStr(action fw.Action) string {
|
||||
if action == fw.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -55,7 +56,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, true)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -67,17 +68,20 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
@@ -87,21 +91,28 @@ func TestIptablesManager(t *testing.T) {
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
for _, r := range rule2 {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -114,11 +125,11 @@ func TestIptablesManager(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -143,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, true)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -155,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
@@ -165,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
@@ -180,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -206,7 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||
t.Helper()
|
||||
t.Helper()
|
||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||
require.NoError(t, err, "failed to check rule")
|
||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
||||
@@ -232,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, true)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -243,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
_, err = manager.client(net.ParseIP("10.20.0.100"))
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
|
||||
340
client/firewall/iptables/router_linux.go
Normal file
340
client/firewall/iptables/router_linux.go
Normal file
@@ -0,0 +1,340 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
)
|
||||
|
||||
type routerManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
}
|
||||
|
||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
m := &routerManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
}
|
||||
|
||||
err := m.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
err = m.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts an iptable rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) RouteingFwChainName() string {
|
||||
return chainRTFWD
|
||||
}
|
||||
|
||||
func (i *routerManager) Reset() error {
|
||||
err := i.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules = make(map[string][]string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createContainers() error {
|
||||
if i.rules[Ipv4Forwarding] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMSGFormat := "failed creating chain %s,error: %v"
|
||||
err := i.createChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||
}
|
||||
|
||||
err = i.createChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *routerManager) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTFWD}
|
||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[Ipv4Forwarding] = rule
|
||||
|
||||
rule = []string{"-j", chainRTNAT}
|
||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *routerManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||
rule, found := i.rules[Ipv4Forwarding]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4Nat]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createChain(table, newChain string) error {
|
||||
chains, err := i.iptablesClient.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = i.iptablesClient.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification with comment identifier
|
||||
func genRuleSpec(jump, id, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
ruleType := "forwarding"
|
||||
if table == tableNat {
|
||||
ruleType = "nat"
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
229
client/firewall/iptables/router_linux_test.go
Normal file
229
client/firewall/iptables/router_linux_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
_, err4 := exec.LookPath("iptables")
|
||||
return err4 == nil
|
||||
}
|
||||
|
||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.100.0/24",
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset iptables manager: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "income forwarding rule should exist")
|
||||
|
||||
foundRule, found = manager.rules[inForwardRuleKey]
|
||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "income nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[inNatRuleKey]
|
||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "income forwarding rule should not exist")
|
||||
|
||||
_, found = manager.rules[inForwardRuleKey]
|
||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "income nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,7 @@ type Rule struct {
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
dst bool
|
||||
v6 bool
|
||||
chain string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
50
client/firewall/iptables/rulestore_linux.go
Normal file
50
client/firewall/iptables/rulestore_linux.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package iptables
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]ipList // ipsetName -> ruleset
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
s.ipsets[ipsetName] = ipList{}
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipsetNames() []string {
|
||||
names := make([]string, 0, len(s.ipsets))
|
||||
for name := range s.ipsets {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
package firewall
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
NatFormat = "netbird-nat-%s"
|
||||
ForwardingFormat = "netbird-fwd-%s"
|
||||
InNatFormat = "netbird-nat-in-%s"
|
||||
InForwardingFormat = "netbird-fwd-in-%s"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
@@ -27,10 +35,8 @@ const (
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionUnknown is a unknown action
|
||||
ActionUnknown Action = iota
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept
|
||||
ActionAccept Action = iota
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
@@ -56,16 +62,27 @@ type Manager interface {
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair RouterPair) error
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
|
||||
func GenKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package firewall
|
||||
package manager
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
18
client/firewall/manager/routerpair.go
Normal file
18
client/firewall/manager/routerpair.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package manager
|
||||
|
||||
type RouterPair struct {
|
||||
ID string
|
||||
Source string
|
||||
Destination string
|
||||
Masquerade bool
|
||||
}
|
||||
|
||||
func GetInPair(pair RouterPair) RouterPair {
|
||||
return RouterPair{
|
||||
ID: pair.ID,
|
||||
// invert Source/Destination
|
||||
Source: pair.Destination,
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
}
|
||||
}
|
||||
1196
client/firewall/nftables/acl_linux.go
Normal file
1196
client/firewall/nftables/acl_linux.go
Normal file
File diff suppressed because it is too large
Load Diff
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsetReference map[string]int
|
||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsetReference: make(map[string]int),
|
||||
ipsets: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||
s.ipsetReference[ipsetName] = 0
|
||||
ipList := make(map[string]struct{})
|
||||
s.ipsets[ipsetName] = ipList
|
||||
return ipList
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsetReference, ipsetName)
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ipList, ip.String())
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ipList[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = ipList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||
s.ipsetReference[ipsetName]++
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||
r, ok := s.ipsetReference[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if r == 0 {
|
||||
return
|
||||
}
|
||||
s.ipsetReference[ipsetName]--
|
||||
}
|
||||
|
||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||
return false
|
||||
}
|
||||
if s.ipsetReference[ipsetName] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -2,90 +2,52 @@ package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
// FilterTableName is the name of the table that is used for filtering by the Netbird client
|
||||
FilterTableName = "netbird-acl"
|
||||
|
||||
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
|
||||
FilterInputChainName = "netbird-acl-input-filter"
|
||||
|
||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
|
||||
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||
tableName = "netbird"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
filterInputChainIPv4 *nftables.Chain
|
||||
filterOutputChainIPv4 *nftables.Chain
|
||||
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
rulesetManager *rulesetManager
|
||||
setRemovedIPs map[string]struct{}
|
||||
setRemoved map[string]*nftables.Set
|
||||
|
||||
mutex sync.Mutex
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
|
||||
rulesetManager: newRuleManager(),
|
||||
setRemovedIPs: map[string]struct{}{},
|
||||
setRemoved: map[string]*nftables.Set{},
|
||||
|
||||
wgIface: wgIface,
|
||||
m.router, err = newRouter(context, workTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -98,649 +60,66 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var (
|
||||
err error
|
||||
ipset *nftables.Set
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterOutputChainName,
|
||||
nftables.ChainHookOutput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
} else {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterInputChainName,
|
||||
nftables.ChainHookInput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
if !isSetNew {
|
||||
// if we already have nftables rules with set for given direction
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
}
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
if proto != "all" {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case fw.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case fw.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: protoData,
|
||||
})
|
||||
}
|
||||
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrLen := uint32(len(rawIP))
|
||||
addrOffset := uint32(12)
|
||||
if addrLen == 16 {
|
||||
addrOffset = 8
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: addrOffset,
|
||||
Len: addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 0,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if action == fw.ActionAccept {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
} else {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
func (m *Manager) getRulesetID(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipsetName == "" {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipsetName + rulesetID
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *Manager) createSet(
|
||||
table *nftables.Table,
|
||||
rawIP []byte,
|
||||
name string,
|
||||
) (*nftables.Set, error) {
|
||||
keyType := nftables.TypeIPAddr
|
||||
if len(rawIP) == 16 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
// else we create new ipset and continue creating rule
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
func (m *Manager) chain(
|
||||
ip net.IP,
|
||||
name string,
|
||||
hook nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
cType nftables.ChainType,
|
||||
) (*nftables.Table, *nftables.Chain, error) {
|
||||
var err error
|
||||
|
||||
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterInputChainIPv4, err
|
||||
}
|
||||
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterOutputChainIPv4, err
|
||||
}
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterInputChainIPv6, err
|
||||
}
|
||||
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterOutputChainIPv6, err
|
||||
}
|
||||
|
||||
// table returns the table for the given family of the IP address
|
||||
func (m *Manager) table(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
// we cache access to Netbird ACL table only
|
||||
if tableName != FilterTableName {
|
||||
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
}
|
||||
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if m.tableIPv4 != nil {
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv4 = table
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
if m.tableIPv6 != nil {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv6 = table
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
family nftables.TableFamily,
|
||||
tableName string,
|
||||
name string,
|
||||
hooknum nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
chainType nftables.ChainType,
|
||||
) (*nftables.Chain, error) {
|
||||
table, err := m.table(family, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
for _, c := range chains {
|
||||
if c.Name == name && c.Table.Name == table.Name {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Hooknum: hooknum,
|
||||
Priority: priority,
|
||||
Type: chainType,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
shiftDSTAddr := 0
|
||||
if name == FilterOutputChainName {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
shiftDSTAddr = 1
|
||||
}
|
||||
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
|
||||
if m.wgIface.Address().IP.To4() == nil {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(8 + (16 * shiftDSTAddr)),
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 16,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: mask.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
} else {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(12 + (4 * shiftDSTAddr)),
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
expressions = []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if nativeRule.nftRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
// call twice of delete set element raises error
|
||||
// so we need to check if element is already removed
|
||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.setRemovedIPs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if m.rulesetManager.deleteRule(nativeRule) {
|
||||
// deleteRule indicates that we still have IP in the ruleset
|
||||
// it means we should not remove the nftables rule but need to update set
|
||||
// so we prepare IP to be removed from set on the next flush call
|
||||
return nil
|
||||
}
|
||||
|
||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
nativeRule.nftRule = nil
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||
}
|
||||
nativeRule.nftSet = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.aclManager.DeleteRule(rule)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.rConn.DelChain(c)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *Manager) Flush() error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// set must be removed after flush rule changes
|
||||
// otherwise we will get error
|
||||
for _, s := range m.setRemoved {
|
||||
m.rConn.FlushSet(s)
|
||||
m.rConn.DelSet(s)
|
||||
}
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if len(m.setRemoved) > 0 {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.setRemovedIPs = map[string]struct{}{}
|
||||
m.setRemoved = map[string]*nftables.Set{}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
tf := nftables.TableFamilyIPv4
|
||||
if m.wgIface.Address().IP.To4() == nil {
|
||||
tf = nftables.TableFamilyIPv6
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
@@ -774,50 +153,79 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||
if table == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(table, chain)
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) != 0 {
|
||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||
log.Errorf("failed to set rule handle: %v", err)
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
m.router.ResetForwardRules()
|
||||
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
// todo review this method usage
|
||||
func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
@@ -835,7 +243,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(AllowNetbirdInputRuleID),
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
@@ -857,15 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -36,6 +37,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
@@ -53,7 +56,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
@@ -82,14 +85,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
// test expectations:
|
||||
// 1) regular rule
|
||||
// 2) "accept extra routed traffic rule" for the interface
|
||||
// 3) "drop all rule" for the interface
|
||||
require.Len(t, rules, 3, "expected 3 rules")
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
@@ -137,18 +136,17 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule {
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// test expectations:
|
||||
// 1) "accept extra routed traffic rule" for the interface
|
||||
// 2) "drop all rule" for the interface
|
||||
require.Len(t, rules, 2, "expected 2 rules after deletion")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
@@ -173,7 +171,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
|
||||
413
client/firewall/nftables/route_linux.go
Normal file
413
client/firewall/nftables/route_linux.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRouteingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (r *router) RouteingFwChainName() string {
|
||||
return chainNameRouteingFw
|
||||
}
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
var expression []expr.Any
|
||||
if isNat {
|
||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||
} else {
|
||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||
}
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
ruleType = "nat"
|
||||
}
|
||||
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
var offSet uint32
|
||||
if source {
|
||||
offSet = 12 // src offset
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
280
client/firewall/nftables/router_linux_test.go
Normal file
280
client/firewall/nftables/router_linux_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
manager.ResetForwardRules()
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() int {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func createWorkTable() (*nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
}
|
||||
|
||||
func deleteWorkTable() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
@@ -8,9 +10,8 @@ import (
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
|
||||
ruleID string
|
||||
ip []byte
|
||||
ruleID string
|
||||
ip net.IP
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||
type nftRuleset struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
issuedRules map[string]*Rule
|
||||
rulesetID string
|
||||
}
|
||||
|
||||
type rulesetManager struct {
|
||||
rulesets map[string]*nftRuleset
|
||||
|
||||
nftSetName2rulesetID map[string]string
|
||||
issuedRuleID2rulesetID map[string]string
|
||||
}
|
||||
|
||||
func newRuleManager() *rulesetManager {
|
||||
return &rulesetManager{
|
||||
rulesets: map[string]*nftRuleset{},
|
||||
|
||||
nftSetName2rulesetID: map[string]string{},
|
||||
issuedRuleID2rulesetID: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||
ruleset, ok := r.rulesets[rulesetID]
|
||||
return ruleset, ok
|
||||
}
|
||||
|
||||
func (r *rulesetManager) createRuleset(
|
||||
rulesetID string,
|
||||
nftRule *nftables.Rule,
|
||||
nftSet *nftables.Set,
|
||||
) *nftRuleset {
|
||||
ruleset := nftRuleset{
|
||||
rulesetID: rulesetID,
|
||||
nftRule: nftRule,
|
||||
nftSet: nftSet,
|
||||
issuedRules: map[string]*Rule{},
|
||||
}
|
||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||
if nftSet != nil {
|
||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||
}
|
||||
return &ruleset
|
||||
}
|
||||
|
||||
func (r *rulesetManager) addRule(
|
||||
ruleset *nftRuleset,
|
||||
ip []byte,
|
||||
) (*Rule, error) {
|
||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||
return nil, fmt.Errorf("ruleset not found")
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
nftRule: ruleset.nftRule,
|
||||
nftSet: ruleset.nftSet,
|
||||
ruleID: xid.New().String(),
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
ruleset.issuedRules[rule.ruleID] = &rule
|
||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// deleteRule from ruleset and returns true if contains other rules
|
||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ruleset := r.rulesets[rulesetID]
|
||||
if ruleset.nftRule == nil {
|
||||
return false
|
||||
}
|
||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||
delete(ruleset.issuedRules, rule.ruleID)
|
||||
|
||||
if len(ruleset.issuedRules) == 0 {
|
||||
delete(r.rulesets, ruleset.rulesetID)
|
||||
if rule.nftSet != nil {
|
||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||
//
|
||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||
// we set correct handle value to it.
|
||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||
ruleset, ok := r.rulesets[string(split[0])]
|
||||
if !ok {
|
||||
return fmt.Errorf("ruleset not found")
|
||||
}
|
||||
*ruleset.nftRule = *nftRule
|
||||
return nil
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{
|
||||
UserData: []byte(rulesetID),
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_addRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||
|
||||
ruleset2 := &nftRuleset{
|
||||
rulesetID: "ruleset-2",
|
||||
}
|
||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||
require.Error(t, err, "addRule() should have failed")
|
||||
}
|
||||
|
||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
ip2 := []byte("192.168.1.1")
|
||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule2, "rule should not be nil")
|
||||
|
||||
hasNext := rulesetManager.deleteRule(rule)
|
||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||
|
||||
// Check that the rule is no longer in the manager.
|
||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||
|
||||
hasNext = rulesetManager.deleteRule(rule2)
|
||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||
}
|
||||
|
||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.0.1")
|
||||
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
nftRuleCopy := nftRule
|
||||
nftRuleCopy.Handle = 2
|
||||
nftRuleCopy.UserData = []byte(rulesetID)
|
||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||
// check correct work with references
|
||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
nftSet := nftables.Set{
|
||||
ID: 2,
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
|
||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||
require.True(t, ok, "getRuleset() failed")
|
||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||
|
||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||
require.False(t, ok, "getRuleset() failed")
|
||||
}
|
||||
47
client/firewall/test/cases_linux.go
Normal file
47
client/firewall/test/cases_linux.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build !android
|
||||
|
||||
package test
|
||||
|
||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
var (
|
||||
InsertRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
}{
|
||||
{
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RemoveRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
IpVersion string
|
||||
}{
|
||||
{
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !windows && !linux
|
||||
//go:build !windows
|
||||
|
||||
package uspfilter
|
||||
|
||||
@@ -10,10 +10,16 @@ func (m *Manager) Reset() error {
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.resetHook != nil {
|
||||
return m.resetHook()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
@@ -15,7 +15,7 @@ type Rule struct {
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction fw.RuleDirection
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
drop bool
|
||||
|
||||
@@ -10,12 +10,16 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
@@ -27,12 +31,12 @@ type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
resetHook func() error
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
@@ -52,6 +56,20 @@ type decoder struct {
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.nativeFirewall = nativeFirewall
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -77,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == fw.ActionDrop,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
@@ -118,21 +159,21 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case fw.ProtocolUDP:
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case fw.ProtocolICMP:
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case fw.ProtocolALL:
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == fw.RuleDirectionIN {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
@@ -144,12 +185,11 @@ func (m *Manager) AddFiltering(
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -158,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
@@ -322,7 +362,7 @@ func (m *Manager) AddUDPPacketHook(
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: fw.RuleDirectionOUT,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@@ -333,7 +373,7 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = fw.RuleDirectionIN
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
@@ -370,8 +410,3 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetResetHook which will be executed in the end of Reset method
|
||||
func (m *Manager) SetResetHook(hook func() error) {
|
||||
m.resetHook = hook
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -125,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
for _, r := range rule {
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
for _, r := range rule2 {
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ Sleep 3000
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
SetShellVarContext all
|
||||
|
||||
@@ -11,42 +11,27 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
manager firewall.Manager
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type ipsetInfo struct {
|
||||
name string
|
||||
ipCount int
|
||||
}
|
||||
|
||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
@@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.manager == nil {
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.manager.Flush(); err != nil {
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
@@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||
|
||||
// calculate which IP's can be grouped in by which ipset
|
||||
// to do that we use rule selector (which is just rule properties without IP's)
|
||||
for _, r := range rules {
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipset, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
ipset = &ipsetInfo{}
|
||||
}
|
||||
|
||||
ipset.ipCount++
|
||||
ipsetByRuleSelectors[selector] = ipset
|
||||
}
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
if ipset.name == "" {
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
ipsetName := ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
if applyFailed {
|
||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.manager.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
d.rulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for pairID, rules := range d.rulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.manager.DeleteRule(rule); err != nil {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
d.rulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
// Stop ACL controller and clear firewall state
|
||||
func (d *DefaultManager) Stop() {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if err := d.manager.Reset(); err != nil {
|
||||
log.WithError(err).Error("reset firewall state")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
@@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
protocol := convertToFirewallProtocol(r.Protocol)
|
||||
if protocol == firewall.ProtocolUnknown {
|
||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
action := convertFirewallAction(r.Action)
|
||||
if action == firewall.ActionUnknown {
|
||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
@@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
var err error
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
@@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
d.rulesPairs[ruleID] = rules
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
@@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule), nil
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule), nil
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
@@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
return firewall.ProtocolTCP
|
||||
return firewall.ProtocolTCP, nil
|
||||
case mgmProto.FirewallRule_UDP:
|
||||
return firewall.ProtocolUDP
|
||||
return firewall.ProtocolUDP, nil
|
||||
case mgmProto.FirewallRule_ICMP:
|
||||
return firewall.ProtocolICMP
|
||||
return firewall.ProtocolICMP, nil
|
||||
case mgmProto.FirewallRule_ALL:
|
||||
return firewall.ProtocolALL
|
||||
return firewall.ProtocolALL, nil
|
||||
default:
|
||||
return firewall.ProtocolUnknown
|
||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||
}
|
||||
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
||||
switch action {
|
||||
case mgmProto.FirewallRule_ACCEPT:
|
||||
return firewall.ActionAccept
|
||||
return firewall.ActionAccept, nil
|
||||
case mgmProto.FirewallRule_DROP:
|
||||
return firewall.ActionDrop
|
||||
return firewall.ActionDrop, nil
|
||||
default:
|
||||
return firewall.ActionUnknown
|
||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance
|
||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package acl
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance for the Linux
|
||||
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var err error
|
||||
|
||||
checkResult := checkfw.Check()
|
||||
switch checkResult {
|
||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||
log.Debug("creating an iptables firewall manager for access control")
|
||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
||||
log.Infof("failed to create iptables manager for access control: %s", err)
|
||||
}
|
||||
case checkfw.NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager for access control")
|
||||
if fm, err = nftables.Create(iface); err != nil {
|
||||
log.Debugf("failed to create nftables manager for access control: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var resetHookForUserspace func() error
|
||||
if fm != nil && err == nil {
|
||||
// err shadowing is used here, to ignore this error
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
resetHookForUserspace = fm.Reset
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
usfm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set kernel space firewall Reset as hook for userspace firewall
|
||||
// manager Reset method, to clean up
|
||||
if resetHookForUserspace != nil {
|
||||
usfm.SetResetHook(resetHookForUserspace)
|
||||
}
|
||||
|
||||
// to be consistent for any future extensions.
|
||||
// ignore this error
|
||||
if err := usfm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
fm = usfm
|
||||
}
|
||||
|
||||
if fm == nil || err != nil {
|
||||
log.Errorf("failed to create firewall manager: %s", err)
|
||||
// no firewall manager found or initialized correctly
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -35,7 +38,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
@@ -49,12 +52,15 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(ifaceMock)
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer acl.Stop()
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap)
|
||||
@@ -325,7 +331,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
@@ -339,12 +345,15 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(ifaceMock)
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer acl.Stop()
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ type HTTPClient interface {
|
||||
}
|
||||
|
||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||
type AuthFlowInfo struct {
|
||||
type AuthFlowInfo struct { //nolint:revive
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package checkfw
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package checkfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN FWType = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
||||
IPTABLESWITHV6
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func Check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err == nil {
|
||||
if isIptablesClientAvailable(ip) {
|
||||
ipSupport := IPTABLES
|
||||
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if ip6Err == nil {
|
||||
if isIptablesClientAvailable(ipv6) {
|
||||
ipSupport = IPTABLESWITHV6
|
||||
}
|
||||
}
|
||||
return ipSupport
|
||||
}
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -12,16 +13,19 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// managementLegacyPortString is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
||||
ManagementLegacyPort = 33073
|
||||
managementLegacyPortString = "33073"
|
||||
// DefaultManagementURL points to the NetBird's cloud management endpoint
|
||||
DefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||
DefaultManagementURL = "https://api.netbird.io:443"
|
||||
// oldDefaultManagementURL points to the NetBird's old cloud management endpoint
|
||||
oldDefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||
// DefaultAdminURL points to NetBird's cloud management console
|
||||
DefaultAdminURL = "https://app.netbird.io:443"
|
||||
)
|
||||
@@ -31,12 +35,18 @@ var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun",
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
PreSharedKey *string
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
DisableAutoConnect *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -50,10 +60,13 @@ type Config struct {
|
||||
WgPort int
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
// ExternalIP mappings, if different than the host interface IP
|
||||
// ExternalIP mappings, if different from the host interface IP
|
||||
//
|
||||
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
|
||||
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
|
||||
@@ -71,6 +84,10 @@ type Config struct {
|
||||
NATExternalIPs []string
|
||||
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
|
||||
CustomDNSAddress string
|
||||
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -80,6 +97,7 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -136,15 +154,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
SSHKey: string(pem),
|
||||
PrivateKey: wgKey,
|
||||
WgIface: iface.WgInterfaceDefault,
|
||||
WgPort: iface.DefaultWgPort,
|
||||
IFaceBlackList: []string{},
|
||||
DisableIPv6Discovery: false,
|
||||
NATExternalIPs: input.NATExternalIPs,
|
||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||
ServerSSHAllowed: util.False(),
|
||||
DisableAutoConnect: false,
|
||||
}
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
@@ -161,10 +180,32 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config.ManagementURL = URL
|
||||
}
|
||||
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
}
|
||||
|
||||
config.WgIface = iface.WgInterfaceDefault
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil {
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
}
|
||||
|
||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -215,12 +256,9 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
log.Infof("new pre-shared key provided, replacing old key")
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
@@ -236,6 +274,17 @@ func update(input ConfigInput) (*Config, error) {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
||||
config.NATExternalIPs = input.NATExternalIPs
|
||||
refresh = true
|
||||
@@ -246,6 +295,31 @@ func update(input ConfigInput) (*Config, error) {
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.DisableAutoConnect != nil {
|
||||
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.ServerSSHAllowed == nil {
|
||||
config.ServerSSHAllowed = util.True()
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if refresh {
|
||||
// since we have new management URL, we need to update config file
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
@@ -305,3 +379,86 @@ func configFileIsExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
@@ -60,22 +63,7 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
@@ -134,3 +122,60 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
previousManagementURL string
|
||||
expectedManagementURL string
|
||||
fileShouldNotChange bool
|
||||
}{
|
||||
{
|
||||
name: "Update old management URL with legacy port",
|
||||
previousManagementURL: "https://api.wiretrustee.com:33073",
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
},
|
||||
{
|
||||
name: "Update old management URL",
|
||||
previousManagementURL: oldDefaultManagementURL,
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
},
|
||||
{
|
||||
name: "No update needed when management URL is up to date",
|
||||
previousManagementURL: DefaultManagementURL,
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
fileShouldNotChange: true,
|
||||
},
|
||||
{
|
||||
name: "No update needed when not using cloud management",
|
||||
previousManagementURL: "https://netbird.example.com:33073",
|
||||
expectedManagementURL: "https://netbird.example.com:33073",
|
||||
fileShouldNotChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: tt.previousManagementURL,
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err, "failed to create testing config")
|
||||
previousStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
||||
require.NoError(t, err, "got error when updating old management url")
|
||||
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
||||
newStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
switch tt.fileShouldNotChange {
|
||||
case true:
|
||||
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
|
||||
case false:
|
||||
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,16 +23,39 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunClientWithProbes runs the client's main logic with probes attached
|
||||
func RunClientWithProbes(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||
func RunClientMobile(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
tunAdapter iface.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []string,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
@@ -40,12 +64,43 @@ func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.S
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||
func RunClientiOS(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
mobileDependency MobileDependency,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
||||
|
||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
|
||||
log.Errorf("checking unclean shutdown error: %s", err)
|
||||
}
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -94,7 +149,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
engineCtx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
statusRecorder.MarkManagementDisconnected()
|
||||
statusRecorder.MarkManagementDisconnected(state.err)
|
||||
statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
@@ -143,8 +198,10 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
statusRecorder.UpdateSignalAddress(signalURL)
|
||||
|
||||
statusRecorder.MarkSignalDisconnected()
|
||||
defer statusRecorder.MarkSignalDisconnected()
|
||||
statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
statusRecorder.MarkSignalDisconnected(state.err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
@@ -172,7 +229,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
@@ -195,7 +252,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
log.Info("stopped NetBird client")
|
||||
|
||||
if _, err := state.Status(); err == ErrResetConnection {
|
||||
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -226,6 +283,9 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -274,83 +334,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
var sri interface{} = statusRecorder
|
||||
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
||||
|
||||
@@ -4,9 +4,11 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"time"
|
||||
)
|
||||
|
||||
const dbusDefaultFlag = 0
|
||||
@@ -14,6 +16,7 @@ const dbusDefaultFlag = 0
|
||||
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
||||
obj, closeConn, err := getDbusObject(dest, path)
|
||||
if err != nil {
|
||||
log.Tracef("error getting dbus object: %s", err)
|
||||
return false
|
||||
}
|
||||
defer closeConn()
|
||||
@@ -21,14 +24,18 @@ func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store()
|
||||
return err == nil
|
||||
if err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
log.Tracef("error calling dbus: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, fmt.Errorf("get dbus: %w", err)
|
||||
}
|
||||
obj := conn.Object(dest, path)
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -23,30 +24,40 @@ const (
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
)
|
||||
|
||||
const (
|
||||
dnsFailoverTimeout = 4 * time.Second
|
||||
dnsFailoverAttempts = 1
|
||||
)
|
||||
|
||||
type fileConfigurator struct {
|
||||
originalPerms os.FileMode
|
||||
repair *repair
|
||||
|
||||
originalPerms os.FileMode
|
||||
nbNameserverIP string
|
||||
}
|
||||
|
||||
func newFileConfigurator() (hostManager, error) {
|
||||
return &fileConfigurator{}, nil
|
||||
fc := &fileConfigurator{}
|
||||
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
backupFileExist := false
|
||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||
if err == nil {
|
||||
backupFileExist = true
|
||||
}
|
||||
|
||||
if !config.routeAll {
|
||||
if !config.RouteAll {
|
||||
if backupFileExist {
|
||||
err = f.restore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err)
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
@@ -55,66 +66,150 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
||||
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
searchDomainList := searchDomains(config)
|
||||
nbSearchDomains := searchDomains(config)
|
||||
f.nbNameserverIP = config.ServerIP
|
||||
|
||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
|
||||
resolvConf, err := parseBackupResolvConf()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
|
||||
f.repair.stopWatchFileChanges()
|
||||
|
||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||
|
||||
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.serverIP}, nameServers...),
|
||||
others)
|
||||
nameServers,
|
||||
options)
|
||||
|
||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
||||
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||
if err != nil {
|
||||
restoreErr := f.restore()
|
||||
if restoreErr != nil {
|
||||
log.Errorf("attempt to restore default file failed with error: %s", err)
|
||||
}
|
||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
|
||||
return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||
|
||||
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restoreHostDNS() error {
|
||||
f.repair.stopWatchFileChanges()
|
||||
return f.restore()
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) backup() error {
|
||||
stats, err := os.Stat(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err)
|
||||
return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
f.originalPerms = stats.Mode()
|
||||
|
||||
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err)
|
||||
return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restore() error {
|
||||
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
resolvConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
// no current nameservers set -> restore
|
||||
if len(resolvConf.nameServers) == 0 {
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
||||
// not a valid first nameserver -> restore
|
||||
if err != nil {
|
||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
// current address is still netbird's non-available dns address -> restore
|
||||
// comparing parsed addresses only, to remove ambiguity
|
||||
if currentDNSAddress.String() == storedDNSAddress.String() {
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreResolvConfFile() error {
|
||||
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
||||
|
||||
if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
|
||||
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
|
||||
ns := make([]string, 1, len(cfg.nameServers)+1)
|
||||
ns[0] = nbNameserverIP
|
||||
for _, cfgNs := range cfg.nameServers {
|
||||
if nbNameserverIP != cfgNs {
|
||||
ns = append(ns, cfgNs)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
||||
@@ -138,83 +233,19 @@ func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes
|
||||
return buf
|
||||
}
|
||||
|
||||
func searchDomains(config hostDNSConfig) []string {
|
||||
func searchDomains(config HostDNSConfig) []string {
|
||||
listOfDomains := make([]string, 0)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.MatchOnly || dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.domain)
|
||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
|
||||
file, err := os.Open(resolvconfFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf(`could not read existing resolv.conf`)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
|
||||
for {
|
||||
lineBytes, isPrefix, readErr := reader.ReadLine()
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if isPrefix {
|
||||
err = fmt.Errorf(`resolv.conf line too long`)
|
||||
return
|
||||
}
|
||||
|
||||
line := strings.TrimSpace(string(lineBytes))
|
||||
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "domain") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
||||
line = strings.ReplaceAll(line, "rotate", "")
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) == 1 {
|
||||
continue
|
||||
}
|
||||
line = strings.Join(splitLines, " ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "search") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
searchDomains = splitLines[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) != 2 {
|
||||
continue
|
||||
}
|
||||
nameServers = append(nameServers, splitLines[1])
|
||||
continue
|
||||
}
|
||||
|
||||
others = append(others, line)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// merge search domains lists and cut off the list if it is too long
|
||||
// merge search Domains lists and cut off the list if it is too long
|
||||
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
||||
lineSize := len("search")
|
||||
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
||||
@@ -225,14 +256,27 @@ func mergeSearchDomains(searchDomains []string, originalSearchDomains []string)
|
||||
return searchDomainsList
|
||||
}
|
||||
|
||||
// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
|
||||
// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
|
||||
// extend s slice with vs elements
|
||||
// return with the number of characters in the searchDomains line
|
||||
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
||||
for _, sd := range vs {
|
||||
duplicated := false
|
||||
for _, fs := range *s {
|
||||
if fs == sd {
|
||||
duplicated = true
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if duplicated {
|
||||
continue
|
||||
}
|
||||
|
||||
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
||||
if tmpCharsNumber > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
// lets log all skipped Domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
||||
continue
|
||||
}
|
||||
@@ -240,29 +284,45 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string
|
||||
initialLineChars = tmpCharsNumber
|
||||
|
||||
if len(*s) >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
// lets log all skipped Domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
||||
continue
|
||||
}
|
||||
*s = append(*s, sd)
|
||||
}
|
||||
|
||||
return initialLineChars
|
||||
}
|
||||
|
||||
func copyFile(src, dest string) error {
|
||||
stats, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err)
|
||||
return fmt.Errorf("checking stats for %s file when copying it. Error: %s", src, err)
|
||||
}
|
||||
|
||||
bytesRead, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err)
|
||||
return fmt.Errorf("reading the file %s file for copy. Error: %s", src, err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(dest, bytesRead, stats.Mode())
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err)
|
||||
return fmt.Errorf("writing the destination file %s for copy. Error: %s", dest, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isContains(subList []string, list []string) bool {
|
||||
for _, sl := range subList {
|
||||
var found bool
|
||||
for _, l := range list {
|
||||
if sl == l {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
@@ -7,7 +9,7 @@ import (
|
||||
|
||||
func Test_mergeSearchDomains(t *testing.T) {
|
||||
searchDomains := []string{"a", "b"}
|
||||
originDomains := []string{"a", "b"}
|
||||
originDomains := []string{"c", "d"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 4 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
||||
@@ -49,6 +51,67 @@ func Test_mergeSearchTooLongDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isContains(t *testing.T) {
|
||||
type args struct {
|
||||
subList []string
|
||||
list []string
|
||||
}
|
||||
tests := []struct {
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a", "b", "c"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"d"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a"},
|
||||
list: []string{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{},
|
||||
list: []string{"b"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{},
|
||||
list: []string{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("list check test", func(t *testing.T) {
|
||||
if got := isContains(tt.args.subList, tt.args.list); got != tt.want {
|
||||
t.Errorf("isContains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getLongLine() string {
|
||||
x := "search "
|
||||
for {
|
||||
|
||||
168
client/internal/dns/file_parser_linux.go
Normal file
168
client/internal/dns/file_parser_linux.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
|
||||
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
|
||||
|
||||
type resolvConf struct {
|
||||
nameServers []string
|
||||
searchDomains []string
|
||||
others []string
|
||||
}
|
||||
|
||||
func (r *resolvConf) String() string {
|
||||
return fmt.Sprintf("search domains: %v, name servers: %v, others: %s", r.searchDomains, r.nameServers, r.others)
|
||||
}
|
||||
|
||||
func parseDefaultResolvConf() (*resolvConf, error) {
|
||||
return parseResolvConfFile(defaultResolvConfPath)
|
||||
}
|
||||
|
||||
func parseBackupResolvConf() (*resolvConf, error) {
|
||||
return parseResolvConfFile(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
rconf := &resolvConf{
|
||||
searchDomains: make([]string, 0),
|
||||
nameServers: make([]string, 0),
|
||||
others: make([]string, 0),
|
||||
}
|
||||
|
||||
file, err := os.Open(resolvConfFile)
|
||||
if err != nil {
|
||||
return rconf, fmt.Errorf("failed to open %s file: %w", resolvConfFile, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("failed closing %s: %s", resolvConfFile, err)
|
||||
}
|
||||
}()
|
||||
|
||||
cur, err := os.ReadFile(resolvConfFile)
|
||||
if err != nil {
|
||||
return rconf, fmt.Errorf("failed to read %s file: %w", resolvConfFile, err)
|
||||
}
|
||||
|
||||
if len(cur) == 0 {
|
||||
return rconf, fmt.Errorf("file is empty")
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(string(cur), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "domain") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
||||
line = strings.ReplaceAll(line, "rotate", "")
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) == 1 {
|
||||
continue
|
||||
}
|
||||
line = strings.Join(splitLines, " ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "search") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
rconf.searchDomains = splitLines[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) != 2 {
|
||||
continue
|
||||
}
|
||||
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
||||
continue
|
||||
}
|
||||
|
||||
if line != "" {
|
||||
rconf.others = append(rconf.others, line)
|
||||
}
|
||||
}
|
||||
return rconf, nil
|
||||
}
|
||||
|
||||
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
|
||||
// otherwise it adds a new option with timeout and attempts.
|
||||
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
|
||||
configs := make([]string, len(input))
|
||||
copy(configs, input)
|
||||
|
||||
for i, config := range configs {
|
||||
if strings.HasPrefix(config, "options") {
|
||||
config = strings.ReplaceAll(config, "rotate", "")
|
||||
config = strings.Join(strings.Fields(config), " ")
|
||||
|
||||
if strings.Contains(config, "timeout:") {
|
||||
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
|
||||
}
|
||||
|
||||
if strings.Contains(config, "attempts:") {
|
||||
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
|
||||
}
|
||||
|
||||
configs[i] = config
|
||||
return configs
|
||||
}
|
||||
}
|
||||
|
||||
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
|
||||
}
|
||||
|
||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
||||
// and writes the file back to the original location
|
||||
func removeFirstNbNameserver(filename, nameserverIP string) error {
|
||||
resolvConf, err := parseResolvConfFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
||||
}
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", filename, err)
|
||||
}
|
||||
|
||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
|
||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
||||
|
||||
stat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", filename, err)
|
||||
}
|
||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
||||
return fmt.Errorf("write %s: %w", filename, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
304
client/internal/dns/file_parser_linux_test.go
Normal file
304
client/internal/dns/file_parser_linux_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseResolvConf(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedSearch []string
|
||||
expectedNS []string
|
||||
expectedOther []string
|
||||
}{
|
||||
{
|
||||
input: `domain example.org
|
||||
search example.org
|
||||
nameserver 192.168.0.1
|
||||
`,
|
||||
expectedSearch: []string{"example.org"},
|
||||
expectedNS: []string{"192.168.0.1"},
|
||||
expectedOther: []string{},
|
||||
},
|
||||
{
|
||||
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.2.1
|
||||
nameserver 100.81.99.197
|
||||
search netbird.cloud
|
||||
`,
|
||||
expectedSearch: []string{"netbird.cloud"},
|
||||
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||
expectedOther: []string{},
|
||||
},
|
||||
{
|
||||
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.2.1
|
||||
nameserver 100.81.99.197
|
||||
search netbird.cloud
|
||||
options debug
|
||||
`,
|
||||
expectedSearch: []string{"netbird.cloud"},
|
||||
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||
expectedOther: []string{"options debug"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run("test", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(testCase.input), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg, err := parseResolvConfFile(tmpResolvConf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ok := compareLists(cfg.searchDomains, testCase.expectedSearch)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for others, expected: %v, got: %v", testCase.expectedOther, cfg.others)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func compareLists(search []string, search2 []string) bool {
|
||||
if len(search) != len(search2) {
|
||||
return false
|
||||
}
|
||||
for i, v := range search {
|
||||
if v != search2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Test_emptyFile(t *testing.T) {
|
||||
cfg, err := parseResolvConfFile("/tmp/nothing")
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
if len(cfg.others) != 0 || len(cfg.searchDomains) != 0 || len(cfg.nameServers) != 0 {
|
||||
t.Errorf("expected empty config, got %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_symlink(t *testing.T) {
|
||||
input := `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.0.1
|
||||
`
|
||||
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(input), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := parseResolvConfFile(tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(cfg.nameServers) != 1 {
|
||||
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOptionsWithTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
others []string
|
||||
timeout int
|
||||
attempts int
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Append new options with timeout and attempts",
|
||||
others: []string{"some config"},
|
||||
timeout: 2,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options timeout:2 attempts:2"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options to exclude rotate and include timeout and attempts",
|
||||
others: []string{"some config", "options rotate someother"},
|
||||
timeout: 3,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
|
||||
},
|
||||
{
|
||||
name: "Existing options with timeout and attempts are updated",
|
||||
others: []string{"some config", "options timeout:4 attempts:3"},
|
||||
timeout: 5,
|
||||
attempts: 4,
|
||||
expected: []string{"some config", "options timeout:5 attempts:4"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options, add missing attempts before timeout",
|
||||
others: []string{"some config", "options timeout:4"},
|
||||
timeout: 4,
|
||||
attempts: 3,
|
||||
expected: []string{"some config", "options attempts:3 timeout:4"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content string
|
||||
ipToRemove string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Unrelated nameservers with comments and options",
|
||||
content: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
},
|
||||
{
|
||||
name: "First nameserver matches",
|
||||
content: `search example.com
|
||||
nameserver 9.9.9.9
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `search example.com
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
},
|
||||
{
|
||||
name: "Target IP not the first nameserver",
|
||||
// nolint:dupword
|
||||
content: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
// nolint:dupword
|
||||
expected: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
},
|
||||
{
|
||||
name: "Only nameserver matches",
|
||||
content: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tempFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
||||
})
|
||||
}
|
||||
}
|
||||
159
client/internal/dns/file_repair_linux.go
Normal file
159
client/internal/dns/file_repair_linux.go
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
eventTypes = []fsnotify.Op{
|
||||
fsnotify.Create,
|
||||
fsnotify.Write,
|
||||
fsnotify.Remove,
|
||||
fsnotify.Rename,
|
||||
}
|
||||
)
|
||||
|
||||
type repairConfFn func([]string, string, *resolvConf) error
|
||||
|
||||
type repair struct {
|
||||
operationFile string
|
||||
updateFn repairConfFn
|
||||
watchDir string
|
||||
|
||||
inotify *fsnotify.Watcher
|
||||
inotifyWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
||||
targetFile := targetFile(operationFile)
|
||||
return &repair{
|
||||
operationFile: targetFile,
|
||||
watchDir: path.Dir(targetFile),
|
||||
updateFn: updateFn,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
||||
if f.inotify != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("start to watch resolv.conf: %s", f.operationFile)
|
||||
inotify, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Errorf("failed to start inotify watcher for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
f.inotify = inotify
|
||||
|
||||
f.inotifyWg.Add(1)
|
||||
go func() {
|
||||
defer f.inotifyWg.Done()
|
||||
for event := range f.inotify.Events {
|
||||
if !f.isEventRelevant(event) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("%s changed, check if it is broken", f.operationFile)
|
||||
|
||||
rConf, err := parseResolvConfFile(f.operationFile)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("check resolv.conf parameters: %s", rConf)
|
||||
if !isNbParamsMissing(nbSearchDomains, nbNameserverIP, rConf) {
|
||||
log.Tracef("resolv.conf still correct, skip the update")
|
||||
continue
|
||||
}
|
||||
log.Info("broken params in resolv.conf, repairing it...")
|
||||
|
||||
err = f.inotify.Remove(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
err = f.inotify.Add(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to re-add inotify watch for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f.inotify.Add(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add inotify watch for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *repair) stopWatchFileChanges() {
|
||||
if f.inotify == nil {
|
||||
return
|
||||
}
|
||||
err := f.inotify.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close resolv.conf inotify: %v", err)
|
||||
}
|
||||
f.inotifyWg.Wait()
|
||||
f.inotify = nil
|
||||
}
|
||||
|
||||
func (f *repair) isEventRelevant(event fsnotify.Event) bool {
|
||||
var ok bool
|
||||
for _, et := range eventTypes {
|
||||
if event.Has(et) {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if event.Name == f.operationFile {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs
|
||||
// check the NetBird related nameserver IP at the first place
|
||||
// check the NetBird related search domains in the search domains list
|
||||
func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool {
|
||||
if !isContains(nbSearchDomains, rConf.searchDomains) {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(rConf.nameServers) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if rConf.nameServers[0] != nbNameserverIP {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func targetFile(filename string) string {
|
||||
target, err := filepath.EvalSymlinks(filename)
|
||||
if err != nil {
|
||||
log.Errorf("evarl err: %s", err)
|
||||
}
|
||||
return target
|
||||
}
|
||||
175
client/internal/dns/file_repair_linux_test.go
Normal file
175
client/internal/dns/file_repair_linux_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func Test_newRepairtmp(t *testing.T) {
|
||||
type args struct {
|
||||
resolvConfContent string
|
||||
touchedConfContent string
|
||||
wantChange bool
|
||||
}
|
||||
tests := []args{
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something somethingelse`,
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
searchdomain netbird.cloud something`,
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
searchdomain something`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 8.8.8.8`,
|
||||
wantChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run("test", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
workDir := t.TempDir()
|
||||
operationFile := workDir + "/resolv.conf"
|
||||
err := os.WriteFile(operationFile, []byte(tt.resolvConfContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(operationFile, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
|
||||
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
r.stopWatchFileChanges()
|
||||
|
||||
if changed != tt.wantChange {
|
||||
t.Errorf("unexpected result: want: %v, got: %v", tt.wantChange, changed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newRepairSymlink(t *testing.T) {
|
||||
resolvConfContent := `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`
|
||||
|
||||
modifyContent := `nameserver 8.8.8.8`
|
||||
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(resolvConfContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(tmpLink, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
|
||||
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
r.stopWatchFileChanges()
|
||||
|
||||
if changed != true {
|
||||
t.Errorf("unexpected result: want: %v, got: %v", true, false)
|
||||
}
|
||||
}
|
||||
@@ -2,37 +2,40 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config hostDNSConfig) error
|
||||
applyDNSConfig(config HostDNSConfig) error
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||
}
|
||||
|
||||
type hostDNSConfig struct {
|
||||
domains []domainConfig
|
||||
routeAll bool
|
||||
serverIP string
|
||||
serverPort int
|
||||
type HostDNSConfig struct {
|
||||
Domains []DomainConfig `json:"domains"`
|
||||
RouteAll bool `json:"routeAll"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
ServerPort int `json:"serverPort"`
|
||||
}
|
||||
|
||||
type domainConfig struct {
|
||||
disabled bool
|
||||
domain string
|
||||
matchOnly bool
|
||||
type DomainConfig struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
Domain string `json:"domain"`
|
||||
MatchOnly bool `json:"matchOnly"`
|
||||
}
|
||||
|
||||
type mockHostConfigurator struct {
|
||||
applyDNSConfigFunc func(config hostDNSConfig) error
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
if m.applyDNSConfigFunc != nil {
|
||||
return m.applyDNSConfigFunc(config)
|
||||
}
|
||||
@@ -53,40 +56,48 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
if m.restoreUncleanShutdownDNSFunc != nil {
|
||||
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
||||
}
|
||||
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
||||
}
|
||||
|
||||
func newNoopHostMocker() hostManager {
|
||||
return &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
|
||||
restoreHostDNSFunc: func() error { return nil },
|
||||
supportCustomPortFunc: func() bool { return true },
|
||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
||||
restoreHostDNSFunc: func() error { return nil },
|
||||
supportCustomPortFunc: func() bool { return true },
|
||||
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
|
||||
config := hostDNSConfig{
|
||||
routeAll: false,
|
||||
serverIP: ip,
|
||||
serverPort: port,
|
||||
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig {
|
||||
config := HostDNSConfig{
|
||||
RouteAll: false,
|
||||
ServerIP: ip,
|
||||
ServerPort: port,
|
||||
}
|
||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||
if len(nsConfig.NameServers) == 0 {
|
||||
continue
|
||||
}
|
||||
if nsConfig.Primary {
|
||||
config.routeAll = true
|
||||
config.RouteAll = true
|
||||
}
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.domains = append(config.domains, domainConfig{
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
matchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
config.domains = append(config.domains, domainConfig{
|
||||
domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
matchOnly: false,
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
func newHostManager() (hostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
|
||||
func (a androidHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -32,7 +36,7 @@ type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager(_ WGIface) (hostManager, error) {
|
||||
func newHostManager() (hostManager, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
@@ -42,21 +46,26 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
|
||||
if config.routeAll {
|
||||
err = s.addDNSSetupForAll(config.serverIP, config.serverPort)
|
||||
if config.RouteAll {
|
||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup for all: %w", err)
|
||||
}
|
||||
} else if s.primaryServiceID != "" {
|
||||
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("remote key from system config: %w", err)
|
||||
}
|
||||
s.primaryServiceID = ""
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -64,37 +73,37 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, dConf.domain)
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort)
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort)
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -117,7 +126,11 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
_, err := runSystemConfigCommand(wrapCommand(lines))
|
||||
if err != nil {
|
||||
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("clean system: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -127,7 +140,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("remove key: %w", err)
|
||||
}
|
||||
|
||||
delete(s.createdKeys, key)
|
||||
@@ -138,7 +151,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
@@ -151,7 +164,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
@@ -176,33 +189,37 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err)
|
||||
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
||||
if primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key")
|
||||
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
|
||||
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
||||
s.primaryServiceID = primaryServiceKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
log.Error("got error while sending the command: ", err)
|
||||
return "", ""
|
||||
return "", "", fmt.Errorf("sending the command: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
primaryService := ""
|
||||
router := ""
|
||||
@@ -215,7 +232,11 @@ func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
}
|
||||
return primaryService, router
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return primaryService, router, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
@@ -226,7 +247,14 @@ func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, e
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while applying dns setup, error: %s", err)
|
||||
return fmt.Errorf("applying dns setup, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -264,7 +292,7 @@ func runSystemConfigCommand(command string) ([]byte, error) {
|
||||
cmd.Stdin = strings.NewReader(command)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err)
|
||||
return nil, fmt.Errorf("running system configuration command: \"%s\", error: %w", command, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
43
client/internal/dns/host_ios.go
Normal file
43
client/internal/dns/host_ios.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type iosHostManager struct {
|
||||
dnsManager IosDnsManager
|
||||
config HostDNSConfig
|
||||
}
|
||||
|
||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
||||
return &iosHostManager{
|
||||
dnsManager: dnsManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
jsonString := string(jsonData)
|
||||
log.Debugf("Applying DNS settings: %s", jsonString)
|
||||
a.dnsManager.ApplyDns(jsonString)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
@@ -4,17 +4,15 @@ package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
const (
|
||||
netbirdManager osManagerType = iota
|
||||
fileManager
|
||||
@@ -23,15 +21,55 @@ const (
|
||||
resolvConfManager
|
||||
)
|
||||
|
||||
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
func newOsManagerType(osManager string) (osManagerType, error) {
|
||||
switch osManager {
|
||||
case "netbird":
|
||||
return fileManager, nil
|
||||
case "file":
|
||||
return netbirdManager, nil
|
||||
case "networkManager":
|
||||
return networkManager, nil
|
||||
case "systemd":
|
||||
return systemdManager, nil
|
||||
case "resolvconf":
|
||||
return resolvConfManager, nil
|
||||
default:
|
||||
return 0, ErrUnknownOsManagerType
|
||||
}
|
||||
}
|
||||
|
||||
func (t osManagerType) String() string {
|
||||
switch t {
|
||||
case netbirdManager:
|
||||
return "netbird"
|
||||
case fileManager:
|
||||
return "file"
|
||||
case networkManager:
|
||||
return "networkManager"
|
||||
case systemdManager:
|
||||
return "systemd"
|
||||
case resolvConfManager:
|
||||
return "resolvconf"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("discovered mode is: %d", osManager)
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
}
|
||||
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
||||
switch osManager {
|
||||
case networkManager:
|
||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||
@@ -45,12 +83,15 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err)
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
@@ -65,11 +106,14 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return systemdManager, nil
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
@@ -85,5 +129,26 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
for _, ns := range rConf.nameServers {
|
||||
if ns == "127.0.0.53" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -9,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match"
|
||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||
dnsPolicyConfigVersionKey = "Version"
|
||||
dnsPolicyConfigVersionValue = 2
|
||||
dnsPolicyConfigNameKey = "Name"
|
||||
@@ -19,7 +21,7 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
)
|
||||
@@ -34,29 +36,38 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newHostManagerWithGuid(guid)
|
||||
}
|
||||
|
||||
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
||||
return ®istryConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *registryConfigurator) supportCustomPort() bool {
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
if config.routeAll {
|
||||
err = r.addDNSSetupForAll(config.serverIP)
|
||||
if config.RouteAll {
|
||||
err = r.addDNSSetupForAll(config.ServerIP)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -64,28 +75,28 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if !dConf.matchOnly {
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+dConf.domain)
|
||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
|
||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
} else {
|
||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
|
||||
err = r.updateSearchDomains(searchDomains)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -94,7 +105,7 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed with error: %s", err)
|
||||
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
||||
}
|
||||
r.routingAll = true
|
||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||
@@ -106,33 +117,33 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
||||
if err == nil {
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
||||
@@ -141,18 +152,25 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreHostDNS() error {
|
||||
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
log.Errorf("remove registry key from dns policy config: %s", err)
|
||||
}
|
||||
|
||||
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding search domain failed with error: %s", err)
|
||||
return fmt.Errorf("adding search domain failed with error: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
||||
@@ -163,13 +181,13 @@ func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
|
||||
regKey, err := r.getInterfaceRegistryKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.SetStringValue(key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err)
|
||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -178,13 +196,13 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
||||
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
|
||||
regKey, err := r.getInterfaceRegistryKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.DeleteValue(propertyKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err)
|
||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -197,20 +215,33 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
k.Close()
|
||||
defer closer(k)
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closer(closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
log.Errorf("failed to close: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
fullRecord.Header().Rdlength = record.Len()
|
||||
@@ -71,3 +71,5 @@ func buildRecordKey(name string, class, qType uint16) string {
|
||||
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
||||
return key
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
|
||||
@@ -33,7 +33,7 @@ func (m *MockServer) DnsIP() string {
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
//TODO implement me
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -48,3 +48,7 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
@@ -5,15 +5,18 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,9 +43,13 @@ const (
|
||||
networkManagerDbusPrimaryDNSPriority int32 = -500
|
||||
networkManagerDbusWithMatchDomainPriority int32 = 0
|
||||
networkManagerDbusSearchDomainOnlyPriority int32 = 50
|
||||
supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28"
|
||||
)
|
||||
|
||||
var supportedNetworkManagerVersionConstraints = []string{
|
||||
">= 1.16, < 1.27",
|
||||
">= 1.44, < 1.45",
|
||||
}
|
||||
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
@@ -70,19 +77,19 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var s string
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface).Store(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("call: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
@@ -93,17 +100,17 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
|
||||
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||
}
|
||||
|
||||
connSettings.cleanDeprecatedSettings()
|
||||
|
||||
dnsIP, err := netip.ParseAddr(config.serverIP)
|
||||
dnsIP, err := netip.ParseAddr(config.ServerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||
}
|
||||
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||
@@ -111,56 +118,70 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
priority := networkManagerDbusSearchDomainOnlyPriority
|
||||
switch {
|
||||
case config.routeAll:
|
||||
case config.RouteAll:
|
||||
priority = networkManagerDbusPrimaryDNSPriority
|
||||
newDomainList = append(newDomainList, "~.")
|
||||
if !n.routingAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
case len(matchDomains) > 0:
|
||||
priority = networkManagerDbusWithMatchDomainPriority
|
||||
}
|
||||
|
||||
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
n.routingAll = false
|
||||
}
|
||||
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||
|
||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||
// The file content itself is not important for network-manager restoration
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||
err = n.reApplyConnectionSettings(connSettings, configVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err)
|
||||
return fmt.Errorf("reapplying the connection with new settings, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
||||
// once the interface is gone network manager cleans all config associated with it
|
||||
return n.deleteConnectionSettings()
|
||||
if err := n.deleteConnectionSettings(); err != nil {
|
||||
return fmt.Errorf("delete connection settings: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return nil, 0, fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
@@ -175,7 +196,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
|
||||
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err)
|
||||
return nil, 0, fmt.Errorf("calling GetAppliedConnection method with context, err: %w", err)
|
||||
}
|
||||
|
||||
return connSettings, configVersion, nil
|
||||
@@ -184,7 +205,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
||||
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
@@ -194,7 +215,7 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
|
||||
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling ReApply method with context, err: %s", err)
|
||||
return fmt.Errorf("calling ReApply method with context, err: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -203,21 +224,34 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
||||
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// this call is required to remove the device for DNS cleanup, even if it fails
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling delete method with context, err: %s", err)
|
||||
var dbusErr dbus.Error
|
||||
if errors.As(err, &dbusErr) && dbusErr.Name == dbus.ErrMsgUnknownMethod.Name {
|
||||
// interface is gone already
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("calling delete method with context, err: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := n.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNetworkManagerSupported() bool {
|
||||
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
|
||||
}
|
||||
@@ -249,13 +283,13 @@ func isNetworkManagerSupportedMode() bool {
|
||||
func getNetworkManagerDNSProperty(property string, store any) error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the network manager dns manager object, error: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
v, e := obj.GetProperty(property)
|
||||
if e != nil {
|
||||
return fmt.Errorf("got an error getting property %s: %v", property, e)
|
||||
return fmt.Errorf("getting property %s: %w", property, e)
|
||||
}
|
||||
|
||||
return v.Store(store)
|
||||
@@ -277,15 +311,26 @@ func isNetworkManagerSupportedVersion() bool {
|
||||
}
|
||||
versionValue, err := parseVersion(value.Value().(string))
|
||||
if err != nil {
|
||||
log.Errorf("nm: parse version: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint)
|
||||
if err != nil {
|
||||
return false
|
||||
var supported bool
|
||||
for _, constraint := range supportedNetworkManagerVersionConstraints {
|
||||
constr, err := version.NewConstraint(constraint)
|
||||
if err != nil {
|
||||
log.Errorf("nm: create constraint: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if met := constr.Check(versionValue); met {
|
||||
supported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return constraints.Check(versionValue)
|
||||
log.Debugf("network manager constraints [%s] met: %t", strings.Join(supportedNetworkManagerVersionConstraints, " | "), supported)
|
||||
return supported
|
||||
}
|
||||
|
||||
func parseVersion(inputVersion string) (*version.Version, error) {
|
||||
|
||||
@@ -52,6 +52,6 @@ func (n *notifier) notify() {
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged()
|
||||
l.OnNetworkChanged("")
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package dns
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -21,17 +22,17 @@ type resolvconf struct {
|
||||
}
|
||||
|
||||
// supported "openresolv" only
|
||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
|
||||
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
||||
resolvConfEntries, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.Name(),
|
||||
originalSearchDomains: originalSearchDomains,
|
||||
originalNameServers: nameServers,
|
||||
othersConfigs: others,
|
||||
ifaceName: wgInterface,
|
||||
originalSearchDomains: resolvConfEntries.searchDomains,
|
||||
originalNameServers: resolvConfEntries.nameServers,
|
||||
othersConfigs: resolvConfEntries.others,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -39,12 +40,12 @@ func (r *resolvconf) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
if !config.routeAll {
|
||||
if !config.RouteAll {
|
||||
err = r.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("restore host dns: %s", err)
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||
}
|
||||
@@ -52,14 +53,21 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
searchDomainList := searchDomains(config)
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
||||
|
||||
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.serverIP}, r.originalNameServers...),
|
||||
r.othersConfigs)
|
||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||
options)
|
||||
|
||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
err = r.applyConfig(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("apply config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
||||
@@ -67,20 +75,34 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
}
|
||||
|
||||
func (r *resolvconf) restoreHostDNS() error {
|
||||
// openresolv only, debian resolvconf doesn't support "-f"
|
||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while removing resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
// openresolv only, debian resolvconf doesn't support "-x"
|
||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
cmd.Stdin = &content
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -31,10 +31,13 @@ func (r *responseWriter) RemoteAddr() net.Addr {
|
||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
buff, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
_, err = r.Write(buff)
|
||||
return err
|
||||
|
||||
if _, err := r.Write(buff); err != nil {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a raw buffer back to the client.
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -19,6 +21,11 @@ type ReadyListener interface {
|
||||
OnReady()
|
||||
}
|
||||
|
||||
// IosDnsManager is a dns manager interface for iOS
|
||||
type IosDnsManager interface {
|
||||
ApplyDns(string)
|
||||
}
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
Initialize() error
|
||||
@@ -27,6 +34,7 @@ type Server interface {
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@@ -43,7 +51,7 @@ type DefaultServer struct {
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
currentConfig HostDNSConfig
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@@ -52,11 +60,15 @@ type DefaultServer struct {
|
||||
|
||||
// make sense on mobile only
|
||||
searchDomainNotifier *notifier
|
||||
iosDnsManager IosDnsManager
|
||||
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
@@ -65,7 +77,12 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -82,13 +99,20 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
|
||||
func NewDefaultServerPermanentUpstream(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
hostsDnsList []string,
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
@@ -99,7 +123,19 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
||||
func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
@@ -109,7 +145,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
|
||||
return defaultServer
|
||||
@@ -127,12 +164,15 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
if s.permanent {
|
||||
err = s.service.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("service listen: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
s.hostManager, err = s.initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
@@ -210,7 +250,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
@@ -223,20 +263,28 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
var searchDomains []string
|
||||
|
||||
for _, dConf := range s.currentConfig.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range s.currentConfig.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
if dConf.MatchOnly {
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
return searchDomains
|
||||
}
|
||||
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
mux.probeAvailability()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
@@ -262,7 +310,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
hostUpdate.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
||||
@@ -273,6 +321,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
}
|
||||
|
||||
s.updateNSGroupStates(update.NameServerGroups)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -312,7 +362,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
continue
|
||||
}
|
||||
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||
@@ -362,6 +421,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
@@ -430,14 +490,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
deactivate = func(err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
@@ -445,29 +505,32 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.service.DeregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.service.DeregisterMux(item.Domain)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, err, false)
|
||||
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.service.RegisterMux(domain, handler)
|
||||
}
|
||||
|
||||
@@ -475,17 +538,29 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
s.currentConfig.RouteAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return
|
||||
}
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
a, err := netip.ParseAddr(ua)
|
||||
@@ -501,7 +576,50 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
||||
}
|
||||
handler.deactivate = func() {}
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
var states []peer.NSGroupState
|
||||
|
||||
for _, group := range groups {
|
||||
var servers []string
|
||||
for _, ns := range group.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
|
||||
state := peer.NSGroupState{
|
||||
ID: generateGroupKey(group),
|
||||
Servers: servers,
|
||||
Domains: group.Domains,
|
||||
// The probe will determine the state, default enabled
|
||||
Enabled: true,
|
||||
Error: nil,
|
||||
}
|
||||
states = append(states, state)
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
|
||||
states := s.statusRecorder.GetDNSStates()
|
||||
id := generateGroupKey(nsGroup)
|
||||
for i, state := range states {
|
||||
if state.ID == id {
|
||||
states[i].Enabled = enabled
|
||||
states[i].Error = err
|
||||
break
|
||||
}
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
var servers []string
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||
}
|
||||
|
||||
5
client/internal/dns/server_android.go
Normal file
5
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
}
|
||||
7
client/internal/dns/server_darwin.go
Normal file
7
client/internal/dns/server_darwin.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
}
|
||||
5
client/internal/dns/server_ios.go
Normal file
5
client/internal/dns/server_ios.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager(s.iosDnsManager)
|
||||
}
|
||||
7
client/internal/dns/server_linux.go
Normal file
7
client/internal/dns/server_linux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager(s.wgInterface.Name())
|
||||
}
|
||||
@@ -12,8 +12,10 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -58,6 +60,10 @@ func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
|
||||
return iface.WGStats{}, nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -250,11 +256,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -268,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -331,7 +338,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@@ -368,7 +376,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -463,7 +471,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@@ -527,23 +535,24 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
hostManager: hostManager,
|
||||
currentConfig: hostDNSConfig{
|
||||
domains: []domainConfig{
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: &peer.Status{},
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.domains {
|
||||
if item.disabled {
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
@@ -556,14 +565,14 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
},
|
||||
}, nil)
|
||||
|
||||
deactivate()
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
@@ -573,11 +582,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
@@ -594,7 +603,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -618,7 +627,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -710,7 +719,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -782,7 +791,8 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
||||
5
client/internal/dns/server_windows.go
Normal file
5
client/internal/dns/server_windows.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager(s.wgInterface)
|
||||
}
|
||||
@@ -28,7 +28,7 @@ type serviceViaListener struct {
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
listenIP string
|
||||
listenPort int
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
@@ -63,18 +63,9 @@ func (s *serviceViaListener) Listen() error {
|
||||
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
||||
if err != nil {
|
||||
log.Errorf("failed to eval runtime address: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
|
||||
if s.shouldApplyPortFwd() {
|
||||
s.ebpfService = ebpf.GetEbpfManagerInstance()
|
||||
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
|
||||
s.ebpfService = nil
|
||||
}
|
||||
}
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
@@ -128,7 +119,7 @@ func (s *serviceViaListener) RuntimePort() int {
|
||||
if s.ebpfService != nil {
|
||||
return defaultPort
|
||||
} else {
|
||||
return s.listenPort
|
||||
return int(s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,54 +131,112 @@ func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
// pick a random port on WG interface for eBPF, if not success
|
||||
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
||||
func (s *serviceViaListener) evalListenAddress() (string, uint16, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), s.customAddr.Port(), nil
|
||||
}
|
||||
|
||||
ip, ok := s.testFreePort(defaultPort)
|
||||
if ok {
|
||||
return ip, defaultPort, nil
|
||||
}
|
||||
|
||||
ebpfSrv, port, ok := s.tryToUseeBPF()
|
||||
if ok {
|
||||
s.ebpfService = ebpfSrv
|
||||
return s.wgInterface.Address().IP.String(), port, nil
|
||||
}
|
||||
|
||||
ip, ok = s.testFreePort(customPort)
|
||||
if ok {
|
||||
return ip, customPort, nil
|
||||
}
|
||||
|
||||
return "", 0, fmt.Errorf("failed to find a free port for DNS server")
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) testFreePort(port int) (string, bool) {
|
||||
var ips []string
|
||||
if runtime.GOOS != "darwin" {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
ips = []string{s.wgInterface.Address().IP.String(), defaultIP, customIP}
|
||||
} else {
|
||||
ips = []string{defaultIP, customIP}
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
|
||||
for _, ip := range ips {
|
||||
if !s.tryToBind(ip, port) {
|
||||
continue
|
||||
}
|
||||
|
||||
return ip, true
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||
}
|
||||
|
||||
return s.getFirstListenerAvailable()
|
||||
}
|
||||
|
||||
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
|
||||
// resolution won't work.
|
||||
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
|
||||
// traffic on port 53 and forward it to a local DNS server running on 5053.
|
||||
func (s *serviceViaListener) shouldApplyPortFwd() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
func (s *serviceViaListener) tryToBind(ip string, port int) bool {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.customAddr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.listenPort == defaultPort {
|
||||
return false
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tryToUseeBPF decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||
// This is needed because on some operating systems if we start a DNS server not on a default port 53,
|
||||
// the domain name resolution won't work. So, in case we are running on Linux and picked a free
|
||||
// port we should fall back to the eBPF solution that will capture traffic on port 53 and forward
|
||||
// it to a local DNS server running on the chosen port.
|
||||
func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
port, err := s.generateFreePort() //nolint:staticcheck,unused
|
||||
if err != nil {
|
||||
log.Warnf("failed to generate a free port for eBPF DNS forwarder server: %s", err)
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
return ebpfSrv, port, true
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) generateFreePort() (uint16, error) {
|
||||
ok := s.tryToBind(s.wgInterface.Address().IP.String(), customPort)
|
||||
if ok {
|
||||
return customPort, nil
|
||||
}
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:0"))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Debugf("failed to bind random port for DNS: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort(probeListener.LocalAddr().String()) // might panic if address is incorrect
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Debugf("failed to free up DNS port: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
return addrPort.Port(), nil
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (s *serviceViaMemory) Listen() error {
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -30,6 +31,8 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
@@ -52,22 +55,22 @@ type systemdDbusLinkDomainsInput struct {
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get interface: %w", err)
|
||||
}
|
||||
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get dbus resolved dest: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
var s string
|
||||
err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get dbus link method: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||
@@ -81,10 +84,10 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
parsedIP, err := netip.ParseAddr(config.serverIP)
|
||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||
}
|
||||
ipAs4 := parsedIP.As4()
|
||||
defaultLinkInput := systemdDbusDNSInput{
|
||||
@@ -93,7 +96,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
}
|
||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err)
|
||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -101,27 +104,27 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
matchDomains []string
|
||||
domainsInput []systemdDbusLinkDomainsInput
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.domain),
|
||||
MatchOnly: dConf.matchOnly,
|
||||
Domain: dns.Fqdn(dConf.Domain),
|
||||
MatchOnly: dConf.MatchOnly,
|
||||
})
|
||||
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, dConf.domain)
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
if config.routeAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
if config.RouteAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
|
||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: nbdns.RootZone,
|
||||
@@ -129,7 +132,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
})
|
||||
s.routingAll = true
|
||||
} else if s.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||
// The file content itself is not important for systemd restoration
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||
@@ -143,7 +152,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
|
||||
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting domains configuration failed with error: %s", err)
|
||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||
}
|
||||
return s.flushCaches()
|
||||
}
|
||||
@@ -153,17 +162,29 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// this call is required for DNS cleanup, even if it fails
|
||||
err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %s", err)
|
||||
var dbusErr dbus.Error
|
||||
if errors.As(err, &dbusErr) && dbusErr.Name == dbusErrorUnknownObject {
|
||||
// interface is gone already
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return s.flushCaches()
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err)
|
||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||
}
|
||||
defer closeConn()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
@@ -171,7 +192,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
|
||||
err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err)
|
||||
return fmt.Errorf("calling the FlushCaches method with context, err: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -180,7 +201,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the object, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
@@ -194,22 +215,29 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling command with context, err: %s", err)
|
||||
return fmt.Errorf("calling command with context, err: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via systemd: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSystemdDbusProperty(property string, store any) error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the systemd dns manager object, error: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the systemd dns manager object, error: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
v, e := obj.GetProperty(property)
|
||||
if e != nil {
|
||||
return fmt.Errorf("got an error getting property %s: %v", property, e)
|
||||
return fmt.Errorf("getting property %s: %w", property, e)
|
||||
}
|
||||
|
||||
return v.Store(store)
|
||||
|
||||
5
client/internal/dns/unclean_shutdown_android.go
Normal file
5
client/internal/dns/unclean_shutdown_android.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
return nil
|
||||
}
|
||||
59
client/internal/dns/unclean_shutdown_darwin.go
Normal file
59
client/internal/dns/unclean_shutdown_darwin.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
|
||||
|
||||
func CheckUncleanShutdown(string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// no file -> clean shutdown
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
|
||||
|
||||
manager, err := newHostManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createUncleanShutdownIndicator() error {
|
||||
dir := filepath.Dir(fileUncleanShutdownFileLocation)
|
||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
|
||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUncleanShutdownIndicator() error {
|
||||
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user