Compare commits

..

7 Commits

Author SHA1 Message Date
Maycon Santos
d9fa28d8a0 test postgres 2023-08-02 18:56:06 +02:00
Maycon Santos
144ac868e0 create empty files 2023-08-02 18:32:31 +02:00
Maycon Santos
b70339d3bd use check_jq 2023-08-02 12:32:43 +02:00
Maycon Santos
ee890971a3 update texts and expose errors to stderr 2023-08-02 12:19:35 +02:00
Maycon Santos
eeb1b619b7 Automate Zitadel IDP configuration (#1006)
Add automated zitadel configuration and netbird using a single script

run infra files test only when they change or when pushed to main

run releases only on changes to related files or when pushed to main

push install and automated getting started script to releases page
2023-08-01 23:31:14 +02:00
Maycon Santos
4b47f6b23c Merge branch 'main' into update-getting-started-flow 2023-07-30 16:39:49 +02:00
Maycon Santos
5883e019c9 update readme 2023-07-07 17:05:21 +02:00
394 changed files with 11572 additions and 20983 deletions

View File

@@ -1,15 +0,0 @@
FROM golang:1.21-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest
WORKDIR /app

View File

@@ -1,20 +0,0 @@
{
"name": "NetBird",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/go:1": {
"version": "1.21"
}
},
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
"capAdd": [
"NET_ADMIN",
"SYS_ADMIN",
"SYS_RESOURCE"
],
"privileged": true
}

1
.gitattributes vendored
View File

@@ -1 +0,0 @@
*.go text eol=lf

View File

@@ -1,48 +0,0 @@
name: Android build validation
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: "1.21.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v3
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -12,20 +12,17 @@ concurrency:
jobs:
test:
strategy:
matrix:
store: ['jsonfile', 'sqlite']
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21.x"
go-version: "1.20.x"
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}
@@ -36,4 +33,4 @@ jobs:
run: go mod tidy
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...

View File

@@ -15,17 +15,16 @@ jobs:
strategy:
matrix:
arch: ['386','amd64']
store: ['jsonfile', 'sqlite']
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21.x"
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +32,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
@@ -42,18 +41,19 @@ jobs:
run: go mod tidy
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21.x"
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -61,10 +61,10 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v2
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: Install modules
run: go mod tidy
@@ -82,7 +82,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
@@ -95,17 +95,15 @@ jobs:
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -23,13 +23,13 @@ jobs:
uses: actions/setup-go@v4
id: go
with:
go-version: "1.21.x"
go-version: "1.20.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
@@ -39,9 +39,7 @@ jobs:
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
@@ -49,4 +47,4 @@ jobs:
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt
run: Get-Content test-out.txt

View File

@@ -1,48 +1,21 @@
name: golangci-lint
on: [pull_request]
permissions:
contents: read
pull-requests: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
codespell:
name: codespell
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
name: lint
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v3
- uses: actions/checkout@v2
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21.x"
cache: false
go-version: "1.20.x"
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v2
with:
version: latest
args: --timeout=12m
args: --timeout=6m

View File

@@ -1,36 +0,0 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -0,0 +1,60 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@@ -0,0 +1,38 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@@ -17,10 +17,9 @@ on:
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
- 'client/ui/**'
env:
SIGN_PIPE_VER: "v0.0.11"
SIGN_PIPE_VER: "v0.0.8"
GORELEASER_VER: "v1.14.1"
concurrency:
@@ -30,32 +29,25 @@ concurrency:
jobs:
release:
runs-on: ubuntu-latest
env:
flags: ""
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21"
cache: false
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
${{ runner.os }}-go-
-
name: Install modules
run: go mod tidy
@@ -64,10 +56,10 @@ jobs:
run: git --no-pager diff --exit-code
-
name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v1
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v1
-
name: Login to Docker hub
if: github.event_name != 'pull_request'
@@ -90,10 +82,10 @@ jobs:
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --rm-dist ${{ env.flags }}
args: release --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -101,7 +93,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release
path: dist/
@@ -110,27 +102,22 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21"
cache: false
go-version: "1.20"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-
${{ runner.os }}-ui-go-
- name: Install modules
run: go mod tidy
@@ -145,17 +132,17 @@ jobs:
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui.yaml --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release-ui
path: dist/
@@ -164,44 +151,39 @@ jobs:
release_ui_darwin:
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21"
cache: false
go-version: "1.20"
-
name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v1
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
path: ~/go/pkg/mod
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
${{ runner.os }}-ui-go-
-
name: Install modules
run: go mod tidy
-
name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@v2
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v2
with:
name: release-ui-darwin
path: dist/

View File

@@ -1,22 +0,0 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -1,23 +0,0 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -8,8 +8,7 @@ on:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -26,12 +25,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v2
with:
go-version: "1.21.x"
go-version: "1.20.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -58,11 +57,9 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
working-directory: infrastructure_files/artifacts
working-directory: infrastructure_files
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
@@ -84,13 +81,8 @@ jobs:
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
run: |
set -x
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -100,53 +92,27 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
- name: Install modules
run: go mod tidy
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
- name: run docker compose up
working-directory: infrastructure_files/artifacts
working-directory: infrastructure_files
run: |
docker-compose up -d
sleep 5
@@ -155,9 +121,9 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
test $count -eq 4
working-directory: infrastructure_files/artifacts
working-directory: infrastructure_files
test-getting-started-script:
runs-on: ubuntu-latest
@@ -169,7 +135,7 @@ jobs:
uses: actions/checkout@v3
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
run: test -f Caddyfile
@@ -178,11 +144,8 @@ jobs:
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
run: test -f turnserver.conf
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen
run: test -f dashboard.env
run: test -f dashboard.env

21
.gitignore vendored
View File

@@ -6,20 +6,11 @@ bin/
.env
conf.json
http-cmds.sh
setup.env
infrastructure_files/**/Caddyfile
infrastructure_files/**/dashboard.env
infrastructure_files/**/zitadel.env
infrastructure_files/**/management.json
infrastructure_files/**/management-*.json
infrastructure_files/**/docker-compose.yml
infrastructure_files/**/openid-configuration.json
infrastructure_files/**/turnserver.conf
infrastructure_files/**/management.json.bkp.**
infrastructure_files/**/management-*.json.bkp.**
infrastructure_files/**/docker-compose.yml.bkp.**
infrastructure_files/**/openid-configuration.json.bkp.**
infrastructure_files/**/turnserver.conf.bkp.**
infrastructure_files/management.json
infrastructure_files/management-*.json
infrastructure_files/docker-compose.yml
infrastructure_files/openid-configuration.json
infrastructure_files/turnserver.conf
management/management
client/client
client/client.exe
@@ -28,5 +19,3 @@ client/.distfiles/
infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
*.db

View File

@@ -1,123 +0,0 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil

View File

@@ -378,11 +378,6 @@ uploads:
username: dev@wiretrustee.com
method: PUT
checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -54,7 +54,7 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png
- src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -71,7 +71,7 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png
- src: client/ui/disconnected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -91,4 +91,4 @@ uploads:
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com
method: PUT
method: PUT

View File

@@ -19,11 +19,11 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
- [Development setup](#development-setup)
- [Requirements](#requirements)
- [Local NetBird setup](#local-netbird-setup)
- [Dev Container Support](#dev-container-support)
- [Build and start](#build-and-start)
- [Test suite](#test-suite)
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
- [Other project repositories](#other-project-repositories)
- [Checklist before submitting a new node](#checklist-before-submitting-a-new-node)
- [Contributor License Agreement](#contributor-license-agreement)
## Code of conduct
@@ -70,7 +70,7 @@ dependencies are installed. Here is a short guide on how that can be done.
### Requirements
#### Go 1.21
#### Go 1.19
Follow the installation guide from https://go.dev/
@@ -136,61 +136,18 @@ checked out and set up:
go mod tidy
```
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
#### 1. Prerequisites:
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
#### 2. Clone the Repository:
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
#### 3. Open in project in IDE of your choice:
**VScode**:
Open the project folder in Visual Studio Code:
```bash
code .
```
When you open the project in VS Code, it will detect the presence of a dev container configuration.
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
**Goland**:
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
#### 4. Wait for the Container to Build:
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
#### 6. Development:
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
### Build and start
#### Client
> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh).
To start NetBird, execute:
```
cd client
CGO_ENABLED=0 go build .
# bash wireguard_nt.sh # if windows
go build .
```
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
> To test the client GUI application on Windows machines with RDP or vituralized environments (e.g. virtualbox or cloud), you need to download and extract the opengl32.dll from https://fdossena.com/?p=mesa/index.frag next to the built application.
To start NetBird the client in the foreground:
```
@@ -228,42 +185,6 @@ To start NetBird the management service:
./management management --log-level debug --log-file console --config ./management.json
```
#### Windows Netbird Installer
Create dist directory
```shell
mkdir -p dist/netbird_windows_amd64
```
UI client
```shell
CC=x86_64-w64-mingw32-gcc CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -o netbird-ui.exe -ldflags "-s -w -H windowsgui" ./client/ui
mv netbird-ui.exe ./dist/netbird_windows_amd64/
```
Client
```shell
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o netbird.exe ./client/
mv netbird.exe ./dist/netbird_windows_amd64/
```
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to `./dist/netbird_windows_amd64/`.
NSIS compiler
- [Windows-nsis]( https://nsis.sourceforge.io/Download)
- [MacOS-makensis](https://formulae.brew.sh/formula/makensis#default)
- [Linux-makensis](https://manpages.ubuntu.com/manpages/trusty/man1/makensis.1.html)
NSIS Plugins. Download and move them to the NSIS plugins folder.
- [EnVar](https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip)
- [ShellExecAsUser](https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z)
Windows Installer
```shell
export APPVER=0.0.0.1
makensis -V4 client/installer.nsis
```
The installer `netbird-installer.exe` will be created in root directory.
### Test suite
The tests can be started via:
@@ -294,4 +215,4 @@ NetBird project is composed of 3 main repositories:
That we do not have any potential problems later it is sadly necessary to sign a [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md). That can be done literally with the push of a button.
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.

View File

@@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
<strong>:hatching_chick: New Release! Peer expiration.</strong>
<a href="https://github.com/netbirdio/netbird/releases">
Learn more
</a>
</p>
@@ -24,7 +24,7 @@
<p align="center">
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
@@ -36,62 +36,48 @@
<br>
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**NetBird is an open-source VPN management platform built on top of WireGuard® making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
It requires zero configuration effort leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
NetBird uses [NAT traversal techniques](https://en.wikipedia.org/wiki/Interactive_Connectivity_Establishment) to automatically create an overlay peer-to-peer network connecting machines regardless of location (home, office, data center, container, cloud, or edge environments), unifying virtual private network management experience.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
- \[x] Private DNS.
- \[x] Network Activity Monitoring.
- \[x] Mobile clients (Android).
-
**Coming soon:**
- \[ ] Mobile clients (iOS).
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Key features
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
| Connectivity | Management | Automation | Platforms |
|---------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
| <ul><li> - \[x] Post-quantum-secure connection through [Rosenpass](https://rosenpass.eu) </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
| | <ul><li> - \[x] SSH access management </ul></li> | | |
### Start using NetBird
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
- Web UI [repository](https://github.com/netbirdio/dashboard).
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
### A bit on NetBird internals
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
@@ -103,18 +89,18 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Roadmap
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
- [NetBird installer script](https://github.com/physk/netbird-installer)
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
@@ -122,7 +108,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), and [Coturn](https://github.com/coturn/coturn). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -18,9 +18,10 @@ func Encode(num uint32) string {
}
var encoded strings.Builder
remainder := uint32(0)
for num > 0 {
remainder := num % base
remainder = num % base
encoded.WriteByte(alphabet[remainder])
num /= base
}

View File

@@ -1,5 +1,7 @@
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
FROM gcr.io/distroless/base:debug
ENV NB_FOREGROUND_MODE=true
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
SHELL ["/busybox/sh","-c"]
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@@ -8,8 +8,8 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -31,9 +31,9 @@ type IFaceDiscover interface {
stdnet.ExternalIFaceDiscover
}
// NetworkChangeListener export internal NetworkChangeListener for mobile
type NetworkChangeListener interface {
listener.NetworkChangeListener
// RouteListener export internal RouteListener for mobile
type RouteListener interface {
routemanager.RouteListener
}
// DnsReadyListener export internal dns ReadyListener for mobile
@@ -47,26 +47,27 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
networkChangeListener listener.NetworkChangeListener
cfgFile string
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
routeListener routemanager.RouteListener
onHostDnsFn func([]string)
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
cfgFile: cfgFile,
deviceName: deviceName,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
routeListener: routeListener,
}
}
@@ -96,31 +97,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources

View File

@@ -84,14 +84,10 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false
err = nil
}
@@ -193,7 +189,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
if err != nil {
return nil, err
}
@@ -205,8 +201,8 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
go urlOpener.Open(flowInfo.VerificationURIComplete)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {

View File

@@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) {
p.SetManagementURL(exampleString)
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
t.Fatalf("failed to read managmenet url: %s", err)
}
if resp != exampleString {
t.Errorf("unexpected management url: %s", resp)
t.Errorf("unexpected managemenet url: %s", resp)
}
p.SetPreSharedKey(exampleString)
@@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) {
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
t.Fatalf("failed to read managmenet url: %s", err)
}
if resp != exampleURL {
t.Errorf("unexpected management url: %s", resp)
t.Errorf("unexpected managemenet url: %s", resp)
}
resp, err = p.GetPreSharedKey()

View File

@@ -3,20 +3,21 @@ package cmd
import (
"context"
"fmt"
"os"
"github.com/netbirdio/netbird/client/internal/auth"
"strings"
"time"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
var loginCmd = &cobra.Command{
@@ -51,7 +52,7 @@ var loginCmd = &cobra.Command{
AdminURL: adminURL,
ConfigPath: configPath,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
@@ -60,7 +61,7 @@ var loginCmd = &cobra.Command{
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
@@ -81,11 +82,9 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
}
var loginErr error
@@ -115,7 +114,7 @@ var loginCmd = &cobra.Command{
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
@@ -151,21 +150,13 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
jwtToken = tokenInfo.GetTokenToUse()
}
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
@@ -174,7 +165,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
if err != nil {
return nil, err
}
@@ -186,8 +177,8 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
waitTimeout := time.Duration(flowInfo.ExpiresIn)
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
@@ -200,21 +191,17 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
if userCode != "" {
if !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
}
cmd.Println("Please do the SSO login in your browser. \n" +
err := open.Run(verificationURIComplete)
cmd.Printf("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
" " + verificationURIComplete + " " + codeMsg + " \n\n")
if err != nil {
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
}
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -25,12 +25,8 @@ import (
)
const (
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
)
var (
@@ -53,9 +49,6 @@ var (
preSharedKey string
natExternalIPs []string
customDNSAddress string
rosenpassEnabled bool
interfaceName string
wireguardPort uint16
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
@@ -99,9 +92,9 @@ func init() {
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
@@ -125,7 +118,6 @@ func init() {
`An empty string "" clears the previous configuration. `+
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -66,15 +66,13 @@ type statusOutputOverview struct {
}
var (
detailFlag bool
ipv4Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
prefixNamesFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
detailFlag bool
ipv4Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
)
var statusCmd = &cobra.Command{
@@ -85,14 +83,12 @@ var statusCmd = &cobra.Command{
func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
}
@@ -113,9 +109,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx, cmd)
resp, _ := getStatus(ctx, cmd)
if err != nil {
return err
return nil
}
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
@@ -124,7 +120,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
" netbird up \n\n"+
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
resp.GetStatus(),
)
return nil
@@ -137,7 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
outputInformationHolder := convertToStatusOutputOverview(resp)
var statusOutputString string
statusOutputString := ""
switch {
case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
@@ -176,12 +172,8 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
}
@@ -193,26 +185,11 @@ func parseFilters() error {
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
}
ipsFilterMap[addr] = struct{}{}
enableDetailFlagWhenFilterFlag()
}
}
if len(prefixNamesFilter) > 0 {
for _, name := range prefixNamesFilter {
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
}
enableDetailFlagWhenFilterFlag()
}
return nil
}
func enableDetailFlagWhenFilterFlag() {
if !detailFlag && !jsonFlag && !yamlFlag {
detailFlag = true
}
}
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
@@ -257,7 +234,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
continue
}
if isPeerConnected {
peersConnected++
peersConnected = peersConnected + 1
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
@@ -430,7 +407,7 @@ func parsePeers(peers peersStateOutput) string {
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
)
peersString += peerString
peersString = peersString + peerString
}
return peersString
}
@@ -438,7 +415,6 @@ func parsePeers(peers peersStateOutput) string {
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := false
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
@@ -455,15 +431,5 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
break
}
}
}
return statusEval || ipEval || nameEval
return statusEval || ipEval
}

View File

@@ -22,7 +22,6 @@ import (
)
func startTestingServices(t *testing.T) string {
t.Helper()
config := &mgmt.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
@@ -45,7 +44,6 @@ func startTestingServices(t *testing.T) string {
}
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@@ -62,29 +60,28 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
}
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
store, err := mgmt.NewFileStore(config.Datadir, nil)
if err != nil {
t.Fatal(err)
}
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
peersUpdateManager := mgmt.NewPeersUpdateManager()
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore, false)
eventStore)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
if err != nil {
t.Fatal(err)
}
@@ -101,7 +98,6 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
func startClientDaemon(
t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
log "github.com/sirupsen/logrus"
@@ -17,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
@@ -38,8 +36,6 @@ var (
func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -89,24 +85,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
if preSharedKey != "" {
ic.PreSharedKey = &preSharedKey
}
@@ -115,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
@@ -144,7 +123,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
log.Warnf("failed closing dameon gRPC client connection %v", err)
return
}
}()
@@ -162,31 +141,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
}
loginRequest := proto.LoginRequest{
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
}
if cmd.Flag(enableRosenpassFlag).Changed {
loginRequest.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
loginRequest.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
wp := int64(wireguardPort)
loginRequest.WireguardPort = &wp
SetupKey: setupKey,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
}
var loginErr error
@@ -217,7 +178,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
@@ -238,11 +199,11 @@ func validateNATExternalIPs(list []string) error {
subElements := strings.Split(element, "/")
if len(subElements) > 2 {
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag)
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
}
if len(subElements) == 1 && !isValidIP(subElements[0]) {
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
}
last := 0
@@ -260,18 +221,6 @@ func validateNATExternalIPs(list []string) error {
return nil
}
func parseInterfaceName(name string) error {
if runtime.GOOS != "darwin" {
return nil
}
if strings.HasPrefix(name, "utun") {
return nil
}
return fmt.Errorf("invalid interface name %s. Please use the prefix utun followed by a number on MacOS. e.g., utun1 or utun199", name)
}
func validateElement(element string) (int, error) {
if isValidIP(element) {
return ipInputType, nil
@@ -309,7 +258,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
var parsed []byte
if modified {
if !isValidAddrPort(customDNSAddress) {
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
}
if customDNSAddress == "" && logFile != "console" {
parsed = []byte("empty")

View File

@@ -1,32 +0,0 @@
//go:build !linux || android
package firewall
import (
"context"
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
err = fm.AllowNetbird()
if err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}

View File

@@ -1,107 +0,0 @@
//go:build !android
package firewall
import (
"context"
"fmt"
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var errFw error
switch check() {
case IPTABLES:
log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
}
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
if errFw != nil {
return nil, errFw
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -1,17 +1,9 @@
package manager
package firewall
import (
"fmt"
"net"
)
const (
NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s"
InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
@@ -35,8 +27,10 @@ const (
type Action int
const (
// ActionUnknown is a unknown action
ActionUnknown Action = iota
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
ActionAccept
// ActionDrop is the action to drop a packet
ActionDrop
)
@@ -46,9 +40,6 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
@@ -62,27 +53,16 @@ type Manager interface {
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
) (Rule, error)
// DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// Reset firewall to the default state
Reset() error
// Flush the changes to firewall controller
Flush() error
}
func GenKey(format string, input string) string {
return fmt.Sprintf(format, input)
// TODO: migrate routemanager firewal actions to this interface
}

View File

@@ -1,11 +0,0 @@
package firewall
import "github.com/netbirdio/netbird/iface"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}

View File

@@ -1,473 +0,0 @@
package iptables
import (
"fmt"
"net"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
)
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
}
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
}
func (m *aclManager) AddFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
}}, nil
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
}
DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
return err
}
func (m *aclManager) Reset() error {
return m.cleanChains()
}
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
}
return nil
}
func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD",
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("PREROUTING",
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -1,105 +1,262 @@
package iptables
import (
"context"
"fmt"
"net"
"strconv"
"sync"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
const (
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
ChainInputFilterName = "NETBIRD-ACL-INPUT"
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
)
// dropAllDefaultRule in the Netbird chain
var dropAllDefaultRule = []string{"-j", "DROP"}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *routerManager
ipv6Client *iptables.IPTables
inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string
wgIface iFaceMapper
rulesets map[string]ruleset
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
m := &Manager{
wgIface: wgIface,
ipv4Client: iptablesClient,
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client
}
m.router, err = newRouterManager(context, iptablesClient)
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
}
if err := m.Reset(); err != nil {
return nil, fmt.Errorf("failed to reset firewall: %v", err)
}
return m, nil
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
// If comment is empty rule ID is used as comment
func (m *Manager) AddFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
protocol fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) (fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
client, err := m.client(ip)
if err != nil {
return nil, err
}
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String()
if comment == "" {
comment = ruleID
}
if ipsetName != "" {
rs, rsExists := m.rulesets[ipsetName]
if !rsExists {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
direction, action, comment, ipsetName)
if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is output rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
return nil, err
}
} else {
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is input rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
return nil, err
}
}
rule := &Rule{
ruleID: ruleID,
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to assosiate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeleteRule(rule)
}
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
client := m.ipv4Client
if r.v6 {
if m.ipv6Client == nil {
return fmt.Errorf("ipv6 is not supported")
}
client = m.ipv6Client
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if rs, ok := m.rulesets[r.ipsetName]; ok {
// delete IP from ruleset IPs list and ipset
if _, ok := rs.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
return m.router.InsertRoutingRules(pair)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(rs.ips) != 0 {
return nil
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// we delete last IP from the set, that means we need to delete
// set itself and assosiated firewall rule too
delete(m.rulesets, r.ipsetName)
return m.router.RemoveRoutingRules(pair)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...)
}
return client.Delete("filter", ChainInputFilterName, r.specs...)
}
// Reset firewall to the default state
@@ -107,49 +264,192 @@ func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
errAcl := m.aclMgr.Reset()
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
if err := m.reset(m.ipv4Client, "filter"); err != nil {
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
}
errMgr := m.router.Reset()
if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
}
return errAcl
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
if m.ipv6Client != nil {
if err := m.reset(m.ipv6Client, "filter"); err != nil {
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
}
}
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName)
if err != nil {
return fmt.Errorf("failed to check if input chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
}
}
}
ok, err = client.ChainExists(table, ChainOutputFilterName)
if err != nil {
return fmt.Errorf("failed to check if output chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
}
}
}
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil
}
// filterRuleSpecs returns the specs of a filtering rule
func (m *Manager) filterRuleSpecs(
table string, ip net.IP, protocol string, sPort, dPort string,
direction fw.RuleDirection, action fw.Action, comment string,
ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s == "0.0.0.0" || s == "::" {
matchByIP = false
}
switch direction {
case fw.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case fw.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
specs = append(specs, "-j", m.actionToStr(action))
return append(specs, "-m", "comment", "--comment", comment)
}
// rawClient returns corresponding iptables client for the given ip
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
if ip.To4() != nil {
return m.ipv4Client, nil
}
if m.ipv6Client == nil {
return nil, fmt.Errorf("ipv6 is not supported")
}
return m.ipv6Client, nil
}
// client returns client with initialized chain and default rules
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
client, err := m.rawClient(ip)
if err != nil {
return nil, err
}
ok, err := client.ChainExists("filter", ChainInputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
return nil, fmt.Errorf("failed to create input chain: %w", err)
}
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}
}
ok, err = client.ChainExists("filter", ChainOutputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
return nil, fmt.Errorf("failed to create output chain: %w", err)
}
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
}
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
}
}
return client, nil
}
func (m *Manager) actionToStr(action fw.Action) string {
if action == fw.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
if ipsetName == "" {
return ""
} else if sPort != "" && dPort != "" {
return ipsetName + "-sport-dport"
} else if sPort != "" {
return ipsetName + "-sport"
} else if dPort != "" {
return ipsetName + "-dport"
}
return ipsetName
}

View File

@@ -1,7 +1,6 @@
package iptables
import (
"context"
"fmt"
"net"
"testing"
@@ -10,7 +9,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
@@ -34,8 +33,6 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
@@ -56,7 +53,7 @@ func TestIptablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -68,20 +65,17 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
var rule1 fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
})
var rule2 []fw.Rule
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
@@ -91,28 +85,21 @@ func TestIptablesManager(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
@@ -125,11 +112,11 @@ func TestIptablesManager(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
}
})
}
@@ -154,7 +141,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -166,7 +153,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
var rule1 fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
@@ -176,14 +163,12 @@ func TestIptablesManagerIPSet(t *testing.T) {
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
})
var rule2 []fw.Rule
var rule2 fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
@@ -193,29 +178,23 @@ func TestIptablesManagerIPSet(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
@@ -225,7 +204,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
t.Helper()
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule")
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
@@ -251,7 +229,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -262,6 +240,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()
_, err = manager.client(net.ParseIP("10.20.0.100"))
require.NoError(t, err)
ip := net.ParseIP("10.20.0.100")

View File

@@ -1,340 +0,0 @@
//go:build !android
package iptables
import (
"context"
"fmt"
"strings"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
type routerManager struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
}
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
}
err := m.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to cleanup routing rules: %s", err)
return nil, err
}
err = m.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return m, err
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
}
delete(i.rules, ruleKey)
return nil
}
func (i *routerManager) RouteingFwChainName() string {
return chainRTFWD
}
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err
}
}
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil
}
func (i *routerManager) createContainers() error {
if i.rules[Ipv4Forwarding] != nil {
return nil
}
errMSGFormat := "failed creating chain %s,error: %v"
err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
}
err = i.createChain(tableNat, chainRTNAT)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
}

View File

@@ -1,229 +0,0 @@
//go:build !android
package iptables
import (
"context"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
return err4 == nil
}
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
_ = manager.Reset()
}()
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
Source: "100.100.100.1/32",
Destination: "100.100.100.0/24",
Masquerade: true,
}
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
}()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -7,7 +7,8 @@ type Rule struct {
specs []string
ip string
chain string
dst bool
v6 bool
}
// GetRuleID returns the rule id

View File

@@ -1,50 +0,0 @@
package iptables
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}

View File

@@ -1,18 +0,0 @@
package manager
type RouterPair struct {
ID string
Source string
Destination string
Masquerade bool
}
func GetInPair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,85 +0,0 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -2,52 +2,88 @@ package nftables
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
const (
// tableName is the name of the table that is used for filtering by the Netbird client
tableName = "netbird"
// FilterTableName is the name of the table that is used for filtering by the Netbird client
FilterTableName = "netbird-acl"
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
FilterInputChainName = "netbird-acl-input-filter"
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
)
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
mutex sync.Mutex
router *router
aclManager *AclManager
rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
filterInputChainIPv4 *nftables.Chain
filterOutputChainIPv4 *nftables.Chain
filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
}
// Create nftables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
m := &Manager{
rConn: &nftables.Conn{},
rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface,
}
workTable, err := m.createWorkTable()
if err != nil {
return nil, err
}
m.router, err = newRouter(context, workTable)
if err != nil {
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
if err := m.Reset(); err != nil {
return nil, err
}
@@ -60,98 +96,533 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) (fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var (
err error
ipset *nftables.Set
table *nftables.Table
chain *nftables.Chain
)
if direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
ip,
FilterOutputChainName,
nftables.ChainHookOutput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
} else {
table, chain, err = m.chain(
ip,
FilterInputChainName,
nftables.ChainHookInput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
}
if err != nil {
return nil, err
}
rawIP := ip.To4()
if rawIP == nil {
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
rawIP = ip.To16()
}
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
return m.aclManager.DeleteRule(rule)
}
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
return m.router.InsertRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveRoutingRules(pair)
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()
err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
}
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
})
var protoData []byte
switch proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
})
}
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
// in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrLen := uint32(len(rawIP))
addrOffset := uint32(12)
if addrLen == 16 {
addrOffset = 8
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
addrOffset += addrLen
}
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset,
Len: addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
if action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
}
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
// chain returns the chain for the given IP address with specific settings
func (m *Manager) chain(
ip net.IP,
name string,
hook nftables.ChainHook,
priority nftables.ChainPriority,
cType nftables.ChainType,
) (*nftables.Table, *nftables.Chain, error) {
var err error
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, name, hook, priority, cType)
}
if ip.To4() != nil {
if name == FilterInputChainName {
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterInputChainIPv4, err
}
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterOutputChainIPv4, err
}
if name == FilterInputChainName {
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterInputChainIPv6, err
}
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterOutputChainIPv6, err
}
// table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
m.tableIPv4 = table
return m.tableIPv4, nil
}
if m.tableIPv6 != nil {
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
m.tableIPv6 = table
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
}
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family)
if err != nil {
return nil, err
}
chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil {
return nil, fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
if c.Name == name && c.Table.Name == table.Name {
return c, nil
}
}
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: table,
Hooknum: hooknum,
Priority: priority,
Type: chainType,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0
if name == FilterOutputChainName {
ifaceKey = expr.MetaKeyOIFNAME
shiftDSTAddr = 1
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
if m.wgIface.Address().IP.To4() == nil {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(8 + (16 * shiftDSTAddr)),
Len: 16,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 16,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: mask.Unmap().AsSlice(),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
} else {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12 + (4 * shiftDSTAddr)),
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
expressions = []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return chain, nil
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if nativeRule.nftRule == nil {
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
m.applyAllowNetbirdRules(chain)
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
@@ -166,33 +637,18 @@ func (m *Manager) Reset() error {
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
}
}
m.router.ResetForwardRules()
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
if t.Name == FilterTableName {
m.rConn.DelTable(t)
}
}
@@ -203,65 +659,100 @@ func (m *Manager) Reset() error {
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
// todo review this method usage
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.Flush()
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
if err := m.flushWithBackoff(); err != nil {
return err
}
for _, t := range tables {
if t.Name == tableName {
m.rConn.DelTable(t)
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
_ = m.rConn.InsertRule(rule)
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime = backoffTime * 2
continue
}
break
}
return
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(table, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
}
}
return nil
}
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, []byte(n+"\x00"))
return b
}

View File

@@ -1,7 +1,6 @@
package nftables
import (
"context"
"fmt"
"net"
"net/netip"
@@ -13,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
@@ -37,8 +36,6 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
@@ -56,7 +53,7 @@ func TestNftablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
@@ -85,10 +82,14 @@ func TestNftablesManager(t *testing.T) {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rules")
// test expectations:
// 1) regular rule
// 2) "accept extra routed traffic rule" for the interface
// 3) "drop all rule" for the interface
require.Len(t, rules, 3, "expected 3 rules")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
@@ -136,17 +137,18 @@ func TestNftablesManager(t *testing.T) {
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule {
err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected 0 rules after deletion")
// test expectations:
// 1) "accept extra routed traffic rule" for the interface
// 2) "drop all rule" for the interface
require.Len(t, rules, 2, "expected 2 rules after deleteion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@@ -171,7 +173,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(context.Background(), mock)
manager, err := Create(mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)

View File

@@ -1,413 +0,0 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,280 +0,0 @@
//go:build !android
package nftables
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.ResetForwardRules()
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
})
}
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func createWorkTable() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
}
func deleteWorkTable() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
}

View File

@@ -1,8 +1,6 @@
package nftables
import (
"net"
"github.com/google/nftables"
)
@@ -10,8 +8,9 @@ import (
type Rule struct {
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
ruleID string
ip []byte
}
// GetRuleID returns the rule id

View File

@@ -0,0 +1,115 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -0,0 +1,122 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -1,4 +1,4 @@
package manager
package firewall
import (
"strconv"

View File

@@ -1,47 +0,0 @@
//go:build !android
package test
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
var (
InsertRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
}{
{
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: false,
},
},
{
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
RemoveRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
IpVersion string
}{
{
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
)

View File

@@ -1,25 +0,0 @@
//go:build !windows
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset()
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
return nil
}

View File

@@ -1,94 +0,0 @@
package uspfilter
import (
"fmt"
"os/exec"
"syscall"
log "github.com/sirupsen/logrus"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if !isWindowsFirewallReachable() {
return nil
}
if !isFirewallRuleActive(firewallRuleName) {
return nil
}
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !isWindowsFirewallReachable() {
return nil
}
if isFirewallRuleActive(firewallRuleName) {
return nil
}
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
if action == addRule {
args = append(args, extraArgs...)
}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
func isWindowsFirewallReachable() bool {
args := []string{"advfirewall", "show", "allprofiles", "state"}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
if err != nil {
log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err)
return false
}
return true
}
func isFirewallRuleActive(ruleName string) bool {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
return err == nil
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
)
// Rule to handle management of rules
@@ -15,7 +15,7 @@ type Rule struct {
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
direction firewall.RuleDirection
direction fw.RuleDirection
sPort uint16
dPort uint16
drop bool

View File

@@ -10,20 +10,15 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
const layerTypeAll = 0
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key
@@ -31,12 +26,10 @@ type RuleSet map[string]Rule
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
nativeFirewall firewall.Manager
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
mutex sync.RWMutex
}
@@ -56,20 +49,6 @@ type decoder struct {
// Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) {
return create(iface)
}
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface)
if err != nil {
return nil, err
}
mgr.nativeFirewall = nativeFirewall
return mgr, nil
}
func create(iface IFaceMapper) (*Manager, error) {
m := &Manager{
decoders: sync.Pool{
New: func() any {
@@ -86,7 +65,6 @@ func create(iface IFaceMapper) (*Manager, error) {
},
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
}
if err := iface.SetFilter(m); err != nil {
@@ -95,50 +73,27 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil
}
func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil {
return false
} else {
return true
}
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.InsertRoutingRules(pair)
}
// RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.RemoveRoutingRules(pair)
}
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) (fw.Rule, error) {
r := Rule{
id: uuid.New().String(),
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop,
drop: action == fw.ActionDrop,
comment: comment,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
@@ -159,21 +114,21 @@ func (m *Manager) AddFiltering(
}
switch proto {
case firewall.ProtocolTCP:
case fw.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
case fw.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
case fw.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
case fw.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock()
if direction == firewall.RuleDirectionIN {
if direction == fw.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
@@ -185,11 +140,12 @@ func (m *Manager) AddFiltering(
m.outgoingRules[r.ip.String()][r.id] = r
}
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
return &r, nil
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -198,7 +154,7 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if r.direction == firewall.RuleDirectionIN {
if r.direction == fw.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
@@ -215,6 +171,17 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -228,7 +195,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules, true)
}
// dropFilter implements same logic for booth direction of the traffic
// dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -362,7 +329,7 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP,
dPort: dPort,
ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
direction: fw.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook,
}
@@ -373,7 +340,7 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock()
if in {
r.direction = firewall.RuleDirectionIN
r.direction = fw.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
@@ -395,16 +362,14 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
return m.DeleteRule(&r)
}
}
}
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeleteRule(&rule)
return m.DeleteRule(&r)
}
}
}

View File

@@ -10,13 +10,12 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
)
type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@@ -26,13 +25,6 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
@@ -125,32 +117,24 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
for _, r := range rule {
err = m.DeleteRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
err = m.DeleteRule(rule)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
for _, r := range rule2 {
err = m.DeleteRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
err = m.DeleteRule(rule2)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}

View File

@@ -166,9 +166,10 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR"
SetShellVarContext all
SetShellVarContext current
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SetShellVarContext all
SectionEnd
Section -Post
@@ -193,12 +194,12 @@ Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll"
RmDir /r "$INSTDIR"
SetShellVarContext all
SetShellVarContext current
Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
SetShellVarContext all
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
@@ -208,7 +209,8 @@ SectionEnd
Function LaunchLink
SetShellVarContext all
SetShellVarContext current
SetOutPath $INSTDIR
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
SetShellVarContext all
FunctionEnd

View File

@@ -11,34 +11,49 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
Stop()
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
manager firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
}
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock()
defer d.mutex.Unlock()
@@ -54,13 +69,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
time.Since(start), total)
}()
if d.firewall == nil {
if d.manager == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
defer func() {
if err := d.firewall.Flush(); err != nil {
if err := d.manager.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
@@ -110,35 +125,58 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
)
}
applyFailed := false
newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules {
selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName = ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
d.rollBack(newRulePairs)
applyFailed = true
break
}
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
if applyFailed {
log.Error("failed to apply firewall rules, rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
continue
}
}
}
return
}
for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
if err := d.manager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
@@ -149,6 +187,16 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.rulesPairs = newRulePairs
}
// Stop ACL controller and clear firewall state
func (d *DefaultManager) Stop() {
d.mutex.Lock()
defer d.mutex.Unlock()
if err := d.manager.Reset(); err != nil {
log.WithError(err).Error("reset firewall state")
}
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
@@ -158,14 +206,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
protocol := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown {
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
}
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
action := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown {
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
}
var port *firewall.Port
@@ -185,6 +233,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
var rules []firewall.Rule
var err error
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
@@ -198,6 +247,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, err
}
d.rulesPairs[ruleID] = rules
return ruleID, rules, nil
}
@@ -210,24 +260,24 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
rules = append(rules, rule)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.firewall.AddFiltering(
rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
return append(rules, rule), nil
}
func (d *DefaultManager) addOutRules(
@@ -239,24 +289,24 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
rules = append(rules, rule)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.firewall.AddFiltering(
rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
return append(rules, rule), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
@@ -317,7 +367,7 @@ func (d *DefaultManager) squashAcceptRules(
protocols[r.Protocol] = map[string]int{}
}
// special case, when we receive this all network IP address
// special case, when we recieve this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
@@ -344,7 +394,7 @@ func (d *DefaultManager) squashAcceptRules(
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
// only for ther first element ALL, it must be done first
protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.FirewallRule_ALL,
mgmProto.FirewallRule_ICMP,
@@ -412,29 +462,18 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
switch protocol {
case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP, nil
return firewall.ProtocolTCP
case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP, nil
return firewall.ProtocolUDP
case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP, nil
return firewall.ProtocolICMP
case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL, nil
return firewall.ProtocolALL
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
return firewall.ProtocolUnknown
}
}
@@ -442,13 +481,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
switch action {
case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept, nil
return firewall.ActionAccept
case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop, nil
return firewall.ActionDrop
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
return firewall.ActionUnknown
}
}

View File

@@ -0,0 +1,23 @@
//go:build !linux
package acl
import (
"fmt"
"runtime"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// Create creates a firewall manager instance
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
return newDefaultManager(fm), nil
}
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -0,0 +1,33 @@
package acl
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
var fm firewall.Manager
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
if fm, err = uspfilter.Create(iface); err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err
}
} else {
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager: %s", err)
// fallback to iptables
if fm, err = iptables.Create(iface); err != nil {
log.Errorf("failed to create iptables manager: %s", err)
return nil, err
}
}
}
return newDefaultManager(fm), nil
}

View File

@@ -1,16 +1,11 @@
package acl
import (
"context"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -37,30 +32,18 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
acl, err := Create(iface)
if err != nil {
t.Errorf("create firewall: %v", err)
t.Errorf("create ACL manager: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
defer acl.Stop()
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
@@ -195,33 +178,31 @@ func TestDefaultManagerSquashRules(t *testing.T) {
}
r := rules[0]
switch {
case r.PeerIP != "0.0.0.0":
if r.PeerIP != "0.0.0.0" {
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_IN:
} else if r.Direction != mgmProto.FirewallRule_IN {
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
} else if r.Protocol != mgmProto.FirewallRule_ALL {
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1]
switch {
case r.PeerIP != "0.0.0.0":
if r.PeerIP != "0.0.0.0" {
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_OUT:
} else if r.Direction != mgmProto.FirewallRule_OUT {
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
} else if r.Protocol != mgmProto.FirewallRule_ALL {
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -289,7 +270,7 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
t.Errorf("we should got same amount of rules as intput, got %v", len(rules))
}
}
@@ -330,30 +311,18 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
acl, err := Create(iface)
if err != nil {
t.Errorf("create firewall: %v", err)
t.Errorf("create ACL manager: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
defer acl.Stop()
acl.ApplyFiltering(networkMap)

View File

@@ -4,13 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/client/internal"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/netbirdio/netbird/client/internal"
)
// HostedGrantType grant type for device flow on Hosted
@@ -175,7 +174,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
if tokenResponse.Error == "authorization_pending" {
continue
} else if tokenResponse.Error == "slow_down" {
interval += (3 * time.Second)
interval = interval + (3 * time.Second)
ticker.Reset(interval)
continue
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net/http"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
@@ -58,52 +57,34 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
log.Debug("getting device authorization flow info")
// Try to initialize the Device Authorization Flow
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err == nil {
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config)
}
return pkceFlow, nil
}
log.Debugf("getting device authorization flow info failed with error: %v", err)
log.Debugf("falling back to pkce authorization flow info")
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
// If Device Authorization Flow failed, try the PKCE Authorization Flow
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
s, ok := gstatus.FromError(err)
if ok && s.Code() == codes.NotFound {
return nil, fmt.Errorf("no SSO provider returned from management. " +
"Please proceed with setting up this device using setup keys " +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
case ok && s.Code() == codes.Unimplemented:
"If you are using hosting Netbird see documentation at " +
"https://github.com/netbirdio/netbird/tree/main/management for details")
} else if ok && s.Code() == codes.Unimplemented {
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
"please update your server or use Setup Keys to login", config.ManagementURL)
default:
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
} else {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"html/template"
"net"
@@ -26,8 +25,6 @@ var _ OAuthFlow = &PKCEAuthorizationFlow{}
const (
queryState = "state"
queryCode = "code"
queryError = "error"
queryErrorDesc = "error_description"
defaultPKCETimeoutSeconds = 300
)
@@ -79,7 +76,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
}
// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
state, err := randomBytesInHex(24)
if err != nil {
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
@@ -113,79 +110,71 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
}
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("failed to close the server: %v", err)
}
}()
go p.startServer(server, tokenChan, errChan)
go p.startServer(tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
return p.handleOAuthToken(token)
case err := <-errChan:
return TokenInfo{}, err
}
}
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
return
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
if err != nil {
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
return
}
port := parsedURL.Port()
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
defer func() {
if err := server.Shutdown(context.Background()); err != nil {
log.Errorf("error while shutting down pkce flow server: %v", err)
}
}()
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
tokenValidatorFunc := func() (*oauth2.Token, error) {
query := req.URL.Query()
state := query.Get(queryState)
// Prevent timing attacks on state
if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
token, err := tokenValidatorFunc()
if err != nil {
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
renderPKCEFlowTmpl(w, err)
}
renderPKCEFlowTmpl(w, nil)
tokenChan <- token
renderPKCEFlowTmpl(w, nil)
})
server.Handler = mux
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
if err := server.ListenAndServe(); err != nil {
errChan <- err
}
}
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
query := req.URL.Query()
if authError := query.Get(queryError); authError != "" {
authErrorDesc := query.Get(queryErrorDesc)
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
}
// Prevent timing attacks on the state
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
return nil, fmt.Errorf("invalid state")
}
code := query.Get(queryCode)
if code == "" {
return nil, fmt.Errorf("missing code")
}
return p.oAuthConfig.Exchange(
req.Context(),
code,
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
)
}
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
tokenInfo := TokenInfo{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
@@ -197,13 +186,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
tokenInfo.IDToken = idToken
}
// if a provider doesn't support an audience, use the Client ID for token verification
audience := p.providerConfig.Audience
if audience == "" {
audience = p.providerConfig.ClientID
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
)
@@ -43,14 +44,15 @@ func isValidAccessToken(token string, audience string) error {
}
// Audience claim of JWT can be a string or an array of strings
switch aud := claims.Audience.(type) {
case string:
if aud == audience {
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case []interface{}:
for _, audItem := range aud {
if audStr, ok := audItem.(string); ok && audStr == audience {
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}

View File

@@ -1,7 +1,6 @@
package internal
import (
"context"
"fmt"
"net/url"
"os"
@@ -13,19 +12,16 @@ import (
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util"
)
const (
// managementLegacyPortString is the port that was used before by the Management gRPC server.
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
managementLegacyPortString = "33073"
ManagementLegacyPort = 33073
// DefaultManagementURL points to the NetBird's cloud management endpoint
DefaultManagementURL = "https://api.netbird.io:443"
// oldDefaultManagementURL points to the NetBird's old cloud management endpoint
oldDefaultManagementURL = "https://api.wiretrustee.com:443"
DefaultManagementURL = "https://api.wiretrustee.com:443"
// DefaultAdminURL points to NetBird's cloud management console
DefaultAdminURL = "https://app.netbird.io:443"
)
@@ -41,9 +37,6 @@ type ConfigInput struct {
PreSharedKey *string
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
InterfaceName *string
WireguardPort *int
}
// Config Configuration type
@@ -57,11 +50,10 @@ type Config struct {
WgPort int
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
// SSHKey is a private SSH key in a PEM format
SSHKey string
// ExternalIP mappings, if different from the host interface IP
// ExternalIP mappings, if different than the host interface IP
//
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
@@ -144,10 +136,11 @@ func createNewConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
WgIface: iface.WgInterfaceDefault,
WgPort: iface.DefaultWgPort,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
@@ -168,24 +161,10 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config.ManagementURL = URL
}
config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}
config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}
if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
}
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return nil, err
@@ -236,9 +215,12 @@ func update(input ConfigInput) (*Config, error) {
}
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
if *input.PreSharedKey != "" {
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
*input.PreSharedKey, config.PreSharedKey)
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
}
if config.SSHKey == "" {
@@ -254,17 +236,6 @@ func update(input ConfigInput) (*Config, error) {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
refresh = true
}
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
refresh = true
}
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
config.NATExternalIPs = input.NATExternalIPs
refresh = true
@@ -275,11 +246,6 @@ func update(input ConfigInput) (*Config, error) {
refresh = true
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
@@ -307,9 +273,9 @@ func parseURL(serviceName, serviceURL string) (*url.URL, error) {
if parsedMgmtURL.Port() == "" {
switch parsedMgmtURL.Scheme {
case "https":
parsedMgmtURL.Host += ":443"
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
case "http":
parsedMgmtURL.Host += ":80"
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
default:
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
}
@@ -339,86 +305,3 @@ func configFileIsExists(path string) bool {
_, err := os.Stat(path)
return !os.IsNotExist(err)
}
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err
}
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
if err != nil {
return nil, err
}
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
// only do the check for the NetBird's managed version
return config, nil
}
var mgmTlsEnabled bool
if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true
}
if !mgmTlsEnabled {
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
return config, nil
}
if config.ManagementURL.Port() != managementLegacyPortString &&
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
return config, nil
}
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
if err != nil {
return nil, err
}
// here we check whether we could switch from the legacy 33073 port to the new 443
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
config.ManagementURL.String(), newURL.String())
key, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
defer func() {
err = client.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}
// everything is alright => update the config
newConfig, err := UpdateConfig(ConfigInput{
ManagementURL: newURL.String(),
ConfigPath: configPath,
})
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, fmt.Errorf("failed updating config file: %v", err)
}
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
return newConfig, nil
}

View File

@@ -1,16 +1,13 @@
package internal
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"
)
func TestGetConfig(t *testing.T) {
@@ -26,6 +23,9 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
if err != nil {
return
}
managementURL := "https://test.management.url:33071"
adminURL := "https://app.admin.url:443"
path := filepath.Join(t.TempDir(), "config.json")
@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 4: existing config, but new managementURL has been provided -> update config
// case 4: new empty pre-shared key config -> fetch it
newPreSharedKey := ""
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: path,
PreSharedKey: &newPreSharedKey,
})
if err != nil {
return
}
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)
// case 5: existing config, but new managementURL has been provided -> update config
newManagementURL := "https://test.newManagement.url:33071"
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: newManagementURL,
@@ -122,60 +137,3 @@ func TestHiddenPreSharedKey(t *testing.T) {
})
}
}
func TestUpdateOldManagementURL(t *testing.T) {
tests := []struct {
name string
previousManagementURL string
expectedManagementURL string
fileShouldNotChange bool
}{
{
name: "Update old management URL with legacy port",
previousManagementURL: "https://api.wiretrustee.com:33073",
expectedManagementURL: DefaultManagementURL,
},
{
name: "Update old management URL",
previousManagementURL: oldDefaultManagementURL,
expectedManagementURL: DefaultManagementURL,
},
{
name: "No update needed when management URL is up to date",
previousManagementURL: DefaultManagementURL,
expectedManagementURL: DefaultManagementURL,
fileShouldNotChange: true,
},
{
name: "No update needed when not using cloud management",
previousManagementURL: "https://netbird.example.com:33073",
expectedManagementURL: "https://netbird.example.com:33073",
fileShouldNotChange: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ManagementURL: tt.previousManagementURL,
ConfigPath: configPath,
})
require.NoError(t, err, "failed to create testing config")
previousStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
require.NoError(t, err, "got error when updating old management url")
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
newStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
switch tt.fileShouldNotChange {
case true:
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
case false:
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
}
})
}
}

View File

@@ -13,8 +13,8 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -22,7 +22,6 @@ import (
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/version"
)
// RunClient with main logic.
@@ -31,30 +30,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager) error {
mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -108,7 +96,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
cancel()
}()
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
@@ -191,6 +179,8 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done()
statusRecorder.ClientTeardown()
@@ -211,7 +201,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return nil
}
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
@@ -235,7 +224,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
}
if config.PreSharedKey != "" {
@@ -284,6 +272,83 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
return loginResp, nil
}
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err
}
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() {
// only do the check for the NetBird's managed version
return config, nil
}
var mgmTlsEnabled bool
if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true
}
if !mgmTlsEnabled {
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
return config, nil
}
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
if err != nil {
return nil, err
}
// here we check whether we could switch from the legacy 33073 port to the new 443
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
config.ManagementURL.String(), newURL.String())
key, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
defer func() {
err = client.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}
// everything is alright => update the config
newConfig, err := UpdateConfig(ConfigInput{
ManagementURL: newURL.String(),
ConfigPath: configPath,
})
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, fmt.Errorf("failed updating config file: %v", err)
}
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
return newConfig, nil
}
return config, nil
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
var sri interface{} = statusRecorder
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)

View File

@@ -3,26 +3,29 @@
package dns
import (
"bufio"
"bytes"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n"
)
const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
type fileConfigurator struct {
originalPerms os.FileMode
}
@@ -35,14 +38,14 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false
}
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
backupFileExist := false
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
if err == nil {
backupFileExist = true
}
if !config.RouteAll {
if !config.routeAll {
if backupFileExist {
err = f.restore()
if err != nil {
@@ -51,39 +54,53 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
managerType, err := getOSDNSManagerType()
if err != nil {
return err
}
switch managerType {
case fileManager, netbirdManager:
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}
}
default:
// todo improve this and maybe restart DNS manager from scratch
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
}
searchDomainList := searchDomains(config)
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Error(err)
searchDomains += " " + dConf.domain
appendedDomains++
}
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.ServerIP}, nameServers...),
others)
log.Debugf("creating managed file %s", defaultResolvConfPath)
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
if err != nil {
restoreErr := f.restore()
if restoreErr != nil {
err = f.restore()
if err != nil {
log.Errorf("attempt to restore default file failed with error: %s", err)
}
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
return err
}
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
return nil
}
@@ -115,138 +132,15 @@ func (f *fileConfigurator) restore() error {
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
log.Debugf("creating managed file %s", fileName)
var buf bytes.Buffer
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
for _, cfgLine := range others {
buf.WriteString(cfgLine)
buf.WriteString("\n")
}
if len(searchDomains) > 0 {
buf.WriteString("search ")
buf.WriteString(strings.Join(searchDomains, " "))
buf.WriteString("\n")
}
for _, ns := range nameServers {
buf.WriteString("nameserver ")
buf.WriteString(ns)
buf.WriteString("\n")
}
return buf
}
func searchDomains(config HostDNSConfig) []string {
listOfDomains := make([]string, 0)
for _, dConf := range config.Domains {
if dConf.MatchOnly || dConf.Disabled {
continue
}
listOfDomains = append(listOfDomains, dConf.Domain)
}
return listOfDomains
}
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
file, err := os.Open(resolvconfFile)
buf.WriteString(content)
err := os.WriteFile(fileName, buf.Bytes(), permissions)
if err != nil {
err = fmt.Errorf(`could not read existing resolv.conf`)
return
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
}
defer file.Close()
reader := bufio.NewReader(file)
for {
lineBytes, isPrefix, readErr := reader.ReadLine()
if readErr != nil {
break
}
if isPrefix {
err = fmt.Errorf(`resolv.conf line too long`)
return
}
line := strings.TrimSpace(string(lineBytes))
if strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "domain") {
continue
}
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}
if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}
searchDomains = splitLines[1:]
continue
}
if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
nameServers = append(nameServers, splitLines[1])
continue
}
others = append(others, line)
}
return
}
// merge search Domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search")
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
return searchDomainsList
}
// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs {
tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped Domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
continue
}
initialLineChars = tmpCharsNumber
if len(*s) >= fileMaxNumberOfSearchDomains {
// lets log all skipped Domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
continue
}
*s = append(*s, sd)
}
return initialLineChars
return nil
}
func copyFile(src, dest string) error {

View File

@@ -1,62 +0,0 @@
package dns
import (
"fmt"
"testing"
)
func Test_mergeSearchDomains(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"a", "b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 4 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
}
}
func Test_mergeSearchTooMuchDomains(t *testing.T) {
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
originDomains := []string{"h", "i"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}
func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"c", "d", "e", "f", "g"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}
func Test_mergeSearchTooLongDomain(t *testing.T) {
searchDomains := []string{getLongLine()}
originDomains := []string{"b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
searchDomains = []string{"b"}
originDomains = []string{getLongLine()}
mergedDomains = mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
}
func getLongLine() string {
x := "search "
for {
for i := 0; i <= 9; i++ {
if len(x) > fileMaxLineCharsLimit {
return x
}
x = fmt.Sprintf("%s%d", x, i)
}
}
}

View File

@@ -8,31 +8,31 @@ import (
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig) error
applyDNSConfig(config hostDNSConfig) error
restoreHostDNS() error
supportCustomPort() bool
}
type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"`
ServerIP string `json:"serverIP"`
ServerPort int `json:"serverPort"`
type hostDNSConfig struct {
domains []domainConfig
routeAll bool
serverIP string
serverPort int
}
type DomainConfig struct {
Disabled bool `json:"disabled"`
Domain string `json:"domain"`
MatchOnly bool `json:"matchOnly"`
type domainConfig struct {
disabled bool
domain string
matchOnly bool
}
type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig) error
applyDNSConfigFunc func(config hostDNSConfig) error
restoreHostDNSFunc func() error
supportCustomPortFunc func() bool
}
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config)
}
@@ -55,38 +55,38 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
func newNoopHostMocker() hostManager {
return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true },
}
}
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig {
config := HostDNSConfig{
RouteAll: false,
ServerIP: ip,
ServerPort: port,
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
config := hostDNSConfig{
routeAll: false,
serverIP: ip,
serverPort: port,
}
for _, nsConfig := range dnsConfig.NameServerGroups {
if len(nsConfig.NameServers) == 0 {
continue
}
if nsConfig.Primary {
config.RouteAll = true
config.routeAll = true
}
for _, domain := range nsConfig.Domains {
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(domain, "."),
MatchOnly: !nsConfig.SearchDomainsEnabled,
config.domains = append(config.domains, domainConfig{
domain: strings.TrimSuffix(domain, "."),
matchOnly: true,
})
}
}
for _, customZone := range dnsConfig.CustomZones {
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."),
MatchOnly: false,
config.domains = append(config.domains, domainConfig{
domain: strings.TrimSuffix(customZone.Domain, "."),
matchOnly: false,
})
}

View File

@@ -7,7 +7,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil
}
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error {
return nil
}

View File

@@ -1,5 +1,3 @@
//go:build !ios
package dns
import (
@@ -44,11 +42,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true
}
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if config.routeAll {
err = s.addDNSSetupForAll(config.serverIP, config.serverPort)
if err != nil {
return err
}
@@ -58,7 +56,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
return err
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort)
}
var (
@@ -66,20 +64,20 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
)
for _, dConf := range config.Domains {
if dConf.Disabled {
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain)
if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain)
continue
}
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, dConf.domain)
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
@@ -90,7 +88,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
@@ -184,11 +182,12 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver := s.getPrimaryService()
primaryServiceKey := s.getPrimaryService()
if primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key")
}
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port)
if err != nil {
return err
}
@@ -197,32 +196,27 @@ func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string) {
func (s *systemConfigurator) getPrimaryService() string {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
log.Error("got error while sending the command: ", err)
return "", ""
return ""
}
scanner := bufio.NewScanner(bytes.NewReader(b))
primaryService := ""
router := ""
for scanner.Scan() {
text := scanner.Text()
if strings.Contains(text, "PrimaryService") {
primaryService = strings.TrimSpace(strings.Split(text, ":")[1])
}
if strings.Contains(text, "Router") {
router = strings.TrimSpace(strings.Split(text, ":")[1])
return strings.TrimSpace(strings.Split(text, ":")[1])
}
}
return primaryService, router
return ""
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)

View File

@@ -1,37 +0,0 @@
package dns
import (
"encoding/json"
log "github.com/sirupsen/logrus"
)
type iosHostManager struct {
dnsManager IosDnsManager
config HostDNSConfig
}
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
return &iosHostManager{
dnsManager: dnsManager,
}, nil
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
jsonData, err := json.Marshal(config)
if err != nil {
return err
}
jsonString := string(jsonData)
log.Debugf("Applying DNS settings: %s", jsonString)
a.dnsManager.ApplyDns(jsonString)
return nil
}
func (a iosHostManager) restoreHostDNS() error {
return nil
}
func (a iosHostManager) supportCustomPort() bool {
return false
}

View File

@@ -25,30 +25,13 @@ const (
type osManagerType int
func (t osManagerType) String() string {
switch t {
case netbirdManager:
return "netbird"
case fileManager:
return "file"
case networkManager:
return "networkManager"
case systemdManager:
return "systemd"
case resolvConfManager:
return "resolvconf"
default:
return "unknown"
}
}
func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err
}
log.Debugf("discovered mode is: %s", osManager)
log.Debugf("discovered mode is: %d", osManager)
switch osManager {
case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface)
@@ -82,6 +65,7 @@ func getOSDNSManagerType() (osManagerType, error) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {

View File

@@ -22,11 +22,13 @@ const (
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
)
type registryConfigurator struct {
guid string
routingAll bool
guid string
routingAll bool
existingSearchDomains []string
}
func newHostManager(wgInterface WGIface) (hostManager, error) {
@@ -43,10 +45,10 @@ func (s *registryConfigurator) supportCustomPort() bool {
return false
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
var err error
if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP)
if config.routeAll {
err = r.addDNSSetupForAll(config.serverIP)
if err != nil {
return err
}
@@ -56,7 +58,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
return err
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
}
var (
@@ -64,18 +66,18 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
)
for _, dConf := range config.Domains {
if dConf.Disabled {
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if !dConf.MatchOnly {
searchDomains = append(searchDomains, dConf.Domain)
if !dConf.matchOnly {
searchDomains = append(searchDomains, dConf.domain)
}
matchDomains = append(matchDomains, "."+dConf.Domain)
matchDomains = append(matchDomains, "."+dConf.domain)
}
if len(matchDomains) != 0 {
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
} else {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
}
@@ -146,11 +148,30 @@ func (r *registryConfigurator) restoreHostDNS() error {
log.Error(err)
}
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
return r.updateSearchDomains([]string{})
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
if err != nil {
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
}
valueList := strings.Split(value, ",")
setExisting := false
if len(r.existingSearchDomains) == 0 {
r.existingSearchDomains = valueList
setExisting = true
}
if len(domains) == 0 && setExisting {
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
return nil
}
newList := append(r.existingSearchDomains, domains...)
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
if err != nil {
return fmt.Errorf("adding search domain failed with error: %s", err)
}
@@ -214,3 +235,33 @@ func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
}
return nil
}
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
val, _, err := regKey.GetStringValue(key)
if err != nil {
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
}
return val, nil
}
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
}
defer regKey.Close()
err = regKey.SetStringValue(key, value)
if err != nil {
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
}
return nil
}

View File

@@ -1,12 +1,10 @@
package dns
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {

View File

@@ -2,7 +2,6 @@ package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -33,7 +32,7 @@ func (m *MockServer) DnsIP() string {
}
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
// TODO implement me
//TODO implement me
panic("implement me")
}
@@ -44,7 +43,3 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}
func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}

View File

@@ -7,14 +7,13 @@ import (
"encoding/binary"
"fmt"
"net/netip"
"regexp"
"time"
"github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbversion "github.com/netbirdio/netbird/version"
)
const (
@@ -94,7 +93,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false
}
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil {
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
@@ -102,7 +101,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
connSettings.cleanDeprecatedSettings()
dnsIP, err := netip.ParseAddr(config.ServerIP)
dnsIP, err := netip.ParseAddr(config.serverIP)
if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err)
}
@@ -112,33 +111,33 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
searchDomains []string
matchDomains []string
)
for _, dConf := range config.Domains {
if dConf.Disabled {
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
if dConf.matchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
continue
}
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
}
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
newDomainList := append(searchDomains, matchDomains...)
priority := networkManagerDbusSearchDomainOnlyPriority
switch {
case config.RouteAll:
case config.routeAll:
priority = networkManagerDbusPrimaryDNSPriority
newDomainList = append(newDomainList, "~.")
if !n.routingAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
}
case len(matchDomains) > 0:
priority = networkManagerDbusWithMatchDomainPriority
}
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
n.routingAll = false
}
@@ -290,7 +289,12 @@ func isNetworkManagerSupportedVersion() bool {
}
func parseVersion(inputVersion string) (*version.Version, error) {
if inputVersion == "" || !nbversion.SemverRegexp.MatchString(inputVersion) {
reg, err := regexp.Compile(version.SemverRegexpRaw)
if err != nil {
return nil, err
}
if inputVersion == "" || !reg.MatchString(inputVersion) {
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
}

View File

@@ -1,57 +0,0 @@
package dns
import (
"reflect"
"sort"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
)
type notifier struct {
listener listener.NetworkChangeListener
listenerMux sync.Mutex
searchDomains []string
}
func newNotifier(initialSearchDomains []string) *notifier {
sort.Strings(initialSearchDomains)
return &notifier{
searchDomains: initialSearchDomains,
}
}
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *notifier) onNewSearchDomains(searchDomains []string) {
sort.Strings(searchDomains)
if len(n.searchDomains) != len(searchDomains) {
n.searchDomains = searchDomains
n.notify()
return
}
if reflect.DeepEqual(n.searchDomains, searchDomains) {
return
}
n.searchDomains = searchDomains
n.notify()
}
func (n *notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged("")
}(n.listener)
}

View File

@@ -3,9 +3,9 @@
package dns
import (
"bytes"
"fmt"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
)
@@ -14,24 +14,11 @@ const resolvconfCommand = "resolvconf"
type resolvconf struct {
ifaceName string
originalSearchDomains []string
originalNameServers []string
othersConfigs []string
}
// supported "openresolv" only
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
if err != nil {
log.Error(err)
}
return &resolvconf{
ifaceName: wgInterface.Name(),
originalSearchDomains: originalSearchDomains,
originalNameServers: nameServers,
othersConfigs: others,
ifaceName: wgInterface.Name(),
}, nil
}
@@ -39,9 +26,9 @@ func (r *resolvconf) supportCustomPort() bool {
return false
}
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
var err error
if !config.RouteAll {
if !config.routeAll {
err = r.restoreHostDNS()
if err != nil {
log.Error(err)
@@ -49,20 +36,37 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
}
searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.ServerIP}, r.originalNameServers...),
r.othersConfigs)
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
err = r.applyConfig(buf)
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
}
searchDomains += " " + dConf.domain
appendedDomains++
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains)
err = r.applyConfig(content)
if err != nil {
return err
}
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
return nil
}
@@ -75,12 +79,12 @@ func (r *resolvconf) restoreHostDNS() error {
return nil
}
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
func (r *resolvconf) applyConfig(content string) error {
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
cmd.Stdin = &content
cmd.Stdin = strings.NewReader(content)
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
}
return nil
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -19,11 +18,6 @@ type ReadyListener interface {
OnReady()
}
// IosDnsManager is a dns manager interface for iOS
type IosDnsManager interface {
ApplyDns(string)
}
// Server is a dns server interface
type Server interface {
Initialize() error
@@ -31,7 +25,6 @@ type Server interface {
DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
}
type registeredHandlerMap map[string]handlerWithStop
@@ -48,16 +41,12 @@ type DefaultServer struct {
hostManager hostManager
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfig hostDNSConfig
// permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
// make sense on mobile only
searchDomainNotifier *notifier
iosDnsManager IosDnsManager
}
type handlerWithStop interface {
@@ -92,26 +81,16 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
ds.searchDomainNotifier.setListener(listener)
setServerDns(ds)
return ds
}
// NewDefaultServerIos returns a new dns server. It optimized for ios
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
@@ -144,8 +123,8 @@ func (s *DefaultServer) Initialize() (err error) {
}
}
s.hostManager, err = s.initialize()
return err
s.hostManager, err = newHostManager(s.wgInterface)
return
}
// DnsIP returns the DNS resolver server IP address
@@ -233,23 +212,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
}
}
func (s *DefaultServer) SearchDomains() []string {
var searchDomains []string
for _, dConf := range s.currentConfig.Domains {
if dConf.Disabled {
continue
}
if dConf.MatchOnly {
continue
}
searchDomains = append(searchDomains, dConf.Domain)
}
return searchDomains
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if update.ServiceEnable {
_ = s.service.Listen()
@@ -265,7 +229,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
@@ -274,18 +238,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.RouteAll = false
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false
}
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
log.Error(err)
}
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
return nil
}
@@ -325,13 +285,10 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
}
handler := newUpstreamResolver(s.ctx)
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
@@ -349,7 +306,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
// reapply DNS settings, but it not touch the original configuration and serial number
// because it is temporal deactivation until next try
//
// after some period defined by upstream it tries to reactivate self by calling this hook
// after some period defined by upstream it trys to reactivate self by calling this hook
// everything we need here is just to re-apply current configuration because it already
// contains this upstream settings (temporal deactivation not removed it)
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
@@ -461,14 +418,14 @@ func (s *DefaultServer) upstreamCallbacks(
}
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false
s.currentConfig.routeAll = false
}
for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true
s.service.DeregisterMux(item.Domain)
removeIndex[item.Domain] = i
for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true
s.service.DeregisterMux(item.domain)
removeIndex[item.domain] = i
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
@@ -480,46 +437,31 @@ func (s *DefaultServer) upstreamCallbacks(
defer s.mux.Unlock()
for domain, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
continue
}
s.currentConfig.Domains[i].Disabled = false
s.currentConfig.domains[i].disabled = false
s.service.RegisterMux(domain, handler)
}
l := log.WithField("nameservers", nsGroup.NameServers)
l.Debug("reactivate temporary Disabled nameserver group")
l.Debug("reactivate temporary disabled nameserver group")
if nsGroup.Primary {
s.currentConfig.RouteAll = true
s.currentConfig.routeAll = true
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply")
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
return
}
func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler := newUpstreamResolver(s.ctx)
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
continue
}
ipString := ua
if !a.Is4() {
ipString = fmt.Sprintf("[%s]", ua)
}
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
}
handler.deactivate = func() {}
handler.reactivate = func() {}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -1,7 +0,0 @@
//go:build !ios
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -19,6 +19,6 @@ func TestGetServerDns(t *testing.T) {
}
if srvB != srv {
t.Errorf("mismatch dns instances")
t.Errorf("missmatch dns instances")
}
}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.iosDnsManager)
}

View File

@@ -1,7 +0,0 @@
//go:build !android
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -251,12 +250,11 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -324,17 +322,16 @@ func TestUpdateDNSServer(t *testing.T) {
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
@@ -342,7 +339,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
@@ -530,8 +527,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
registeredMap: make(registrationMap),
},
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
currentConfig: hostDNSConfig{
domains: []domainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
@@ -540,13 +537,13 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
for _, item := range config.domains {
if item.disabled {
continue
}
domains = append(domains, item.Domain)
domains = append(domains, item.domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
@@ -562,11 +559,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
deactivate()
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
for _, item := range server.currentConfig.domains {
if item.disabled {
continue
}
domains = append(domains, item.Domain)
domains = append(domains, item.domain)
}
got := strings.Join(domains, ",")
if expected != got {
@@ -576,11 +573,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
for _, item := range server.currentConfig.domains {
if item.disabled {
continue
}
domains = append(domains, item.Domain)
domains = append(domains, item.domain)
}
got = strings.Join(domains, ",")
if expected != got {
@@ -596,8 +593,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
defer wgIFace.Close()
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -620,8 +616,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -712,8 +708,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -774,19 +770,17 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err
return nil, nil
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
@@ -794,7 +788,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
err = wgIface.Create()
if err != nil {
t.Fatalf("create and init wireguard interface: %v", err)
t.Fatalf("crate and init wireguard interface: %v", err)
return nil, err
}

View File

@@ -1,5 +0,0 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@@ -11,9 +11,6 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
const (
@@ -27,11 +24,10 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
listenIP string
listenPort int
runtimeIP string
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
@@ -47,7 +43,6 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
UDPSize: 65535,
},
}
return s
}
@@ -60,21 +55,13 @@ func (s *serviceViaListener) Listen() error {
}
var err error
s.listenIP, s.listenPort, err = s.evalListenAddress()
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
if s.shouldApplyPortFwd() {
s.ebpfService = ebpf.GetEbpfManagerInstance()
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
if err != nil {
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
s.ebpfService = nil
}
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
@@ -82,10 +69,9 @@ func (s *serviceViaListener) Listen() error {
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
return nil
}
@@ -104,13 +90,6 @@ func (s *serviceViaListener) Stop() {
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -122,18 +101,11 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
}
func (s *serviceViaListener) RuntimePort() int {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
} else {
return s.listenPort
}
return s.runtimePort
}
func (s *serviceViaListener) RuntimeIP() string {
return s.listenIP
return s.runtimeIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
@@ -164,30 +136,10 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
// resolution won't work.
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
// traffic on port 53 and forward it to a local DNS server running on 5053.
func (s *serviceViaListener) shouldApplyPortFwd() bool {
if runtime.GOOS != "linux" {
return false
}
if s.customAddr != nil {
return false
}
if s.listenPort == defaultPort {
return false
}
return true
}

View File

@@ -81,8 +81,8 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true
}
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
parsedIP, err := netip.ParseAddr(config.ServerIP)
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
parsedIP, err := netip.ParseAddr(config.serverIP)
if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err)
}
@@ -93,7 +93,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
if err != nil {
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.ServerIP, config.ServerPort, err)
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err)
}
var (
@@ -101,24 +101,24 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
domainsInput []systemdDbusLinkDomainsInput
)
for _, dConf := range config.Domains {
if dConf.Disabled {
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.Domain),
MatchOnly: dConf.MatchOnly,
Domain: dns.Fqdn(dConf.domain),
MatchOnly: dConf.matchOnly,
})
if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain)
if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain)
continue
}
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, dConf.domain)
}
if config.RouteAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
if config.routeAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
@@ -129,7 +129,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
})
s.routingAll = true
} else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -22,15 +21,10 @@ const (
)
type upstreamClient interface {
exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
type upstreamResolver struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
@@ -46,25 +40,25 @@ type upstreamResolverBase struct {
reactivate func()
}
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
ctx, cancel := context.WithCancel(parentCTX)
return &upstreamResolverBase{
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
}
func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
func (u *upstreamResolver) stop() {
log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
}
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
@@ -76,8 +70,10 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
rm, t, err := u.upstreamClient.exchange(upstream, r)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
@@ -87,19 +83,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream).
Error("got other error while querying the upstream")
return
}
if rm == nil {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
Error("got an error while querying the upstream")
return
}
@@ -122,7 +106,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails() {
func (u *upstreamResolver) checkUpstreamFails() {
u.mutex.Lock()
defer u.mutex.Unlock()
@@ -134,18 +118,15 @@ func (u *upstreamResolverBase) checkUpstreamFails() {
case <-u.ctx.Done():
return
default:
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilResponse()
}
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilResponse()
}
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
func (u *upstreamResolver) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
@@ -167,7 +148,10 @@ func (u *upstreamResolverBase) waitUntilResponse() {
var err error
for _, upstream := range u.upstreamServers {
_, _, err = u.upstreamClient.exchange(upstream, r)
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
_, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err == nil {
return nil

View File

@@ -1,93 +0,0 @@
//go:build ios
package dns
import (
"context"
"net"
"syscall"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
iIndex int
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
iIndex: index,
}
ios.upstreamClient = ios
return ios, nil
}
func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
}
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate()
}
return client.Exchange(r, upstream)
}
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate() *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
Port: 0, // Let the OS pick a free port
},
Timeout: upstreamTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
}
if err := c.Control(fn); err != nil {
return err
}
if operr != nil {
log.Errorf("error while setting socket option: %s", operr)
}
return operr
},
}
client := &dns.Client{
Dialer: dialer,
}
return client
}
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
return iface.Index, err
}

View File

@@ -1,32 +0,0 @@
//go:build !ios
package dns
import (
"context"
"net"
"time"
"github.com/miekg/dns"
)
type upstreamResolverNonIOS struct {
*upstreamResolverBase
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase,
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolverNonIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
return rm, t, err
}

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"net"
"strings"
"testing"
"time"
@@ -50,6 +49,15 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
timeout: upstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
@@ -58,7 +66,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
resolver := newUpstreamResolver(ctx)
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
@@ -106,12 +114,12 @@ type mockUpstreamResolver struct {
}
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver
func (c mockUpstreamResolver) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) {
return c.r, c.rtt, c.err
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := &upstreamResolverBase{
resolver := &upstreamResolver{
ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{
err: nil,
@@ -148,7 +156,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
t.Errorf("resolver should be disabled")
return
}

Some files were not shown because too many files have changed in this diff Show More