mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 07:04:17 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
840b07c784 |
@@ -1,8 +0,0 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,14 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -dA output:**
|
||||
**NetBird status -d output:**
|
||||
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
If applicable, add the `netbird status -d' command output.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
|
||||
8
.github/workflows/golang-test-darwin.yml
vendored
8
.github/workflows/golang-test-darwin.yml
vendored
@@ -18,14 +18,14 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
35
.github/workflows/golang-test-freebsd.yml
vendored
35
.github/workflows/golang-test-freebsd.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
@@ -21,26 +21,19 @@ jobs:
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y go
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
||||
22
.github/workflows/golang-test-linux.yml
vendored
22
.github/workflows/golang-test-linux.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -49,18 +49,18 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -80,7 +80,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
@@ -124,4 +124,4 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
|
||||
12
.github/workflows/golangci-lint.yml
vendored
12
.github/workflows/golangci-lint.yml
vendored
@@ -15,11 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif
|
||||
ignore_words_list: erro,clienta,hastable,
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
@@ -32,15 +32,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
@@ -49,4 +49,4 @@ jobs:
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m
|
||||
3
.github/workflows/install-script-test.yml
vendored
3
.github/workflows/install-script-test.yml
vendored
@@ -13,7 +13,6 @@ concurrency:
|
||||
jobs:
|
||||
test-install-script:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
@@ -22,7 +21,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -15,23 +15,23 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -50,11 +50,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
@@ -62,4 +62,4 @@ jobs:
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
CGO_ENABLED: 0
|
||||
174
.github/workflows/release.yml
vendored
174
.github/workflows/release.yml
vendored
@@ -3,16 +3,15 @@ name: Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
- 'v*'
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.14"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -24,26 +23,22 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -51,57 +46,73 @@ jobs:
|
||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-releaser-
|
||||
- name: Install modules
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
-
|
||||
name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Login to Docker hub
|
||||
-
|
||||
name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
username: netbirdio
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc amd64
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows rsrc arm64
|
||||
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows rsrc arm
|
||||
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||
- name: Generate windows rsrc 386
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
args: release --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
- name: upload linux packages
|
||||
uses: actions/upload-artifact@v4
|
||||
-
|
||||
name: upload linux packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
- name: upload windows packages
|
||||
uses: actions/upload-artifact@v4
|
||||
-
|
||||
name: upload windows packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
- name: upload macos packages
|
||||
uses: actions/upload-artifact@v4
|
||||
-
|
||||
name: upload macos packages
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -110,29 +121,22 @@ jobs:
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
@@ -147,23 +151,22 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -174,17 +177,20 @@ jobs:
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -192,34 +198,52 @@ jobs:
|
||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-releaser-darwin-
|
||||
- name: Install modules
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
-
|
||||
name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
|
||||
trigger_signer:
|
||||
trigger_windows_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
needs: [release,release_ui]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
- name: Trigger Windows binaries sign pipeline
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
workflow: Sign windows bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
trigger_darwin_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release,release_ui_darwin]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger Darwin App binaries sign pipeline
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: Sign darwin ui app with dispatch
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
|
||||
97
.github/workflows/test-infrastructure-files.yml
vendored
97
.github/workflows/test-infrastructure-files.yml
vendored
@@ -18,31 +18,7 @@ concurrency:
|
||||
jobs:
|
||||
test-docker-compose:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
env:
|
||||
POSTGRES_USER: netbird
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: netbird
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
ports:
|
||||
- 5432:5432
|
||||
steps:
|
||||
- name: Set Database Connection String
|
||||
run: |
|
||||
if [ "${{ matrix.store }}" == "postgres" ]; then
|
||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
|
||||
else
|
||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Install jq
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
@@ -50,12 +26,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -63,7 +39,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: cp setup.env
|
||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||
@@ -82,8 +58,7 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
@@ -110,8 +85,7 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
|
||||
@@ -149,14 +123,6 @@ jobs:
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||
grep '33445:33445' docker-compose.yml
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -182,35 +148,26 @@ jobs:
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=0 go build -o netbird-relay main.go
|
||||
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker compose up -d
|
||||
docker-compose up -d
|
||||
sleep 5
|
||||
docker compose ps
|
||||
docker compose logs --tail=20
|
||||
docker-compose ps
|
||||
docker-compose logs --tail=20
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
|
||||
test $count -eq 5 || docker compose logs
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files/artifacts
|
||||
|
||||
- name: test geolocation databases
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -219,7 +176,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
@@ -234,7 +191,7 @@ jobs:
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
run: |
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
@@ -245,15 +202,12 @@ jobs:
|
||||
- name: test dashboard.env file gen postgres
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test relay.env file gen postgres
|
||||
run: test -f relay.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker compose down --volumes --rmi all
|
||||
docker-compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
@@ -272,7 +226,7 @@ jobs:
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
@@ -283,5 +237,20 @@ jobs:
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test relay.env file gen CockroachDB
|
||||
run: test -f relay.env
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
GeoLite2-City*
|
||||
104
.goreleaser.yaml
104
.goreleaser.yaml
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird
|
||||
@@ -24,7 +22,7 @@ builds:
|
||||
goarch: 386
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -44,19 +42,19 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
- id: netbird-mgmt
|
||||
dir: management
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-mgmt
|
||||
goos:
|
||||
- linux
|
||||
@@ -66,7 +64,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-signal
|
||||
dir: signal
|
||||
@@ -80,21 +78,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-relay
|
||||
dir: relay
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-relay
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
@@ -102,6 +86,7 @@ archives:
|
||||
- netbird-static
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
@@ -176,52 +161,6 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
@@ -374,18 +313,6 @@ docker_manifests:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/relay:latest
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
@@ -417,9 +344,10 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
-
|
||||
ids:
|
||||
- default
|
||||
repository:
|
||||
tap:
|
||||
owner: netbirdio
|
||||
name: homebrew-tap
|
||||
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
|
||||
@@ -436,7 +364,7 @@ brews:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-deb
|
||||
- netbird-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -458,4 +386,4 @@ checksum:
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./release_files/install.sh
|
||||
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
@@ -13,7 +11,9 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
dir: client/ui
|
||||
@@ -28,7 +28,7 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -H windowsgui
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
archives:
|
||||
- id: linux-arch
|
||||
@@ -41,6 +41,7 @@ archives:
|
||||
- netbird-ui-windows
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
@@ -78,7 +79,7 @@ nfpms:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-ui-deb
|
||||
- netbird-ui-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
@@ -19,7 +17,7 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -30,4 +28,4 @@ archives:
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
|
||||
changelog:
|
||||
disable: true
|
||||
skip: true
|
||||
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
|
||||
|
||||
**Goreleaser**
|
||||
```shell
|
||||
goreleaser build --snapshot --clean
|
||||
goreleaser --snapshot --rm-dist
|
||||
```
|
||||
**golangci-lint**
|
||||
```shell
|
||||
|
||||
@@ -10,14 +10,12 @@
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
</p>
|
||||
@@ -30,7 +28,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3.20
|
||||
FROM alpine:3.19
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ type ConnectionListener interface {
|
||||
|
||||
// TunAdapter export internal TunAdapter for mobile
|
||||
type TunAdapter interface {
|
||||
device.TunAdapter
|
||||
iface.TunAdapter
|
||||
}
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
@@ -51,7 +51,7 @@ func init() {
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter device.TunAdapter
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
|
||||
@@ -178,21 +178,6 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
|
||||
// domain names with random strings.
|
||||
func (a *Anonymizer) AnonymizeRoute(route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@@ -14,8 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@@ -66,17 +63,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
@@ -92,11 +84,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
@@ -125,11 +113,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
@@ -138,20 +122,17 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
@@ -164,11 +145,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
@@ -186,25 +162,21 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
cmd.Println("Netbird up")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
@@ -214,6 +186,16 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
|
||||
@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
@@ -42,8 +42,6 @@ var downCmd = &cobra.Command{
|
||||
log.Errorf("call service down method: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Disconnected")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -39,11 +39,6 @@ var loginCmd = &cobra.Command{
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
@@ -67,7 +62,7 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -86,7 +81,7 @@ var loginCmd = &cobra.Command{
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
SetupKey: setupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ const (
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -56,7 +55,6 @@ var (
|
||||
managementURL string
|
||||
adminURL string
|
||||
setupKey string
|
||||
setupKeyPath string
|
||||
hostName string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
@@ -71,7 +69,6 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
@@ -94,15 +91,12 @@ func init() {
|
||||
oldDefaultConfigPathDir = "/etc/wiretrustee/"
|
||||
oldDefaultLogFileDir = "/var/log/wiretrustee/"
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
|
||||
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
case "freebsd":
|
||||
defaultConfigPathDir = "/var/db/netbird/"
|
||||
}
|
||||
|
||||
defaultConfigPath = defaultConfigPathDir + "config.json"
|
||||
@@ -127,10 +121,8 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
||||
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
@@ -173,8 +165,6 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -256,21 +246,6 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
func getSetupKey() (string, error) {
|
||||
if setupKeyPath != "" && setupKey == "" {
|
||||
return getSetupKeyFromFile(setupKeyPath)
|
||||
}
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
|
||||
data, err := os.ReadFile(setupKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read setup key file: %v", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func handleRebrand(cmd *cobra.Command) error {
|
||||
var err error
|
||||
if logFile == defaultLogFile {
|
||||
|
||||
@@ -4,10 +4,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestInitCommands(t *testing.T) {
|
||||
@@ -38,44 +34,3 @@ func TestInitCommands(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFlagsFromEnvVars(t *testing.T) {
|
||||
var cmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Long: "test",
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`comma separated list of external IPs to map to the Wireguard interface`)
|
||||
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
|
||||
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
|
||||
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
|
||||
t.Setenv("NB_INTERFACE_NAME", "test-name")
|
||||
t.Setenv("NB_ENABLE_ROSENPASS", "true")
|
||||
t.Setenv("NB_WIREGUARD_PORT", "10000")
|
||||
err := cmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error while running netbird command, got %v", err)
|
||||
}
|
||||
if len(natExternalIPs) != 2 {
|
||||
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
|
||||
}
|
||||
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
|
||||
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
|
||||
}
|
||||
if interfaceName != "test-name" {
|
||||
t.Errorf("expected test-name, got %s", interfaceName)
|
||||
}
|
||||
if !rosenpassEnabled {
|
||||
t.Errorf("expected rosenpassEnabled to be true, got false")
|
||||
}
|
||||
if wireguardPort != 10000 {
|
||||
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,21 +2,18 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
||||
@@ -61,8 +61,6 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstance = serverInstance
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
@@ -72,14 +70,6 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
if err != nil {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.cancel()
|
||||
|
||||
if p.serv != nil {
|
||||
|
||||
@@ -31,8 +31,6 @@ var installCmd = &cobra.Command{
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
|
||||
@@ -31,9 +31,9 @@ type peerStateDetailOutput struct {
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
@@ -335,18 +335,16 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
connType := ""
|
||||
peersConnected := 0
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
for _, pbPeerState := range peers {
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
@@ -362,7 +360,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
@@ -375,6 +372,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal,
|
||||
ConnType: connType,
|
||||
Direct: pbPeerState.GetDirect(),
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
@@ -383,7 +381,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
RelayAddress: relayServerAddress,
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
@@ -644,9 +641,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Relay server address: %s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
@@ -658,11 +655,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
peerState.RelayAddress,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
@@ -805,15 +802,12 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -849,8 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ var resp = &proto.StatusResponse{
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
@@ -56,6 +57,7 @@ var resp = &proto.StatusResponse{
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
@@ -135,6 +137,7 @@ var overview = statusOutputOverview{
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
Direct: true,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
@@ -158,6 +161,7 @@ var overview = statusOutputOverview{
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
Direct: false,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
@@ -279,6 +283,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"direct": true,
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
@@ -287,7 +292,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
@@ -304,6 +308,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"direct": false,
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
@@ -312,7 +317,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
@@ -404,13 +408,13 @@ func TestParsingToYAML(t *testing.T) {
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
direct: true
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
@@ -424,13 +428,13 @@ func TestParsingToYAML(t *testing.T) {
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
direct: false
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
@@ -501,9 +505,9 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
@@ -517,9 +521,9 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -10,7 +11,6 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -33,12 +33,18 @@ func startTestingServices(t *testing.T) string {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
config.Datadir = testDir
|
||||
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, signalLis := startSignal(t)
|
||||
signalAddr := signalLis.Addr().String()
|
||||
config.Signal.URI = signalAddr
|
||||
|
||||
_, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
|
||||
_, mgmLis := startManagement(t, config)
|
||||
mgmAddr := mgmLis.Addr().String()
|
||||
return mgmAddr
|
||||
}
|
||||
@@ -50,7 +56,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
@@ -63,15 +69,14 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -83,17 +88,12 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -147,11 +147,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@@ -159,7 +154,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -168,10 +163,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
return connectClient.Run()
|
||||
}
|
||||
|
||||
@@ -207,13 +199,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
SetupKey: setupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
|
||||
@@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -41,36 +40,6 @@ func TestUpDaemon(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test the setup-key-file flag.
|
||||
tempFile, err := os.CreateTemp("", "setup-key")
|
||||
if err != nil {
|
||||
t.Errorf("could not create temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
|
||||
t.Errorf("could not write to temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
if err := tempFile.Close(); err != nil {
|
||||
t.Errorf("unable to close file, got error %v", err)
|
||||
}
|
||||
rootCmd.SetArgs([]string{
|
||||
"login",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
"--setup-key-file", tempFile.Name(),
|
||||
"--log-file", "",
|
||||
})
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
t.Errorf("expected no error while running up command, got %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
|
||||
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
|
||||
return
|
||||
}
|
||||
|
||||
rootCmd.SetArgs([]string{
|
||||
"up",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
@@ -19,22 +19,24 @@ const (
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
|
||||
postRoutingMark = "0x000007e4"
|
||||
)
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
@@ -59,7 +61,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
func (m *aclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -125,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -137,16 +139,28 @@ func (m *aclManager) AddPeerFiltering(
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||
if err != nil {
|
||||
return []firewall.Rule{rule}, err
|
||||
}
|
||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.chain == "PREROUTING" {
|
||||
goto DELETERULE
|
||||
}
|
||||
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
@@ -171,7 +185,14 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
||||
DELETERULE:
|
||||
var table string
|
||||
if r.chain == "PREROUTING" {
|
||||
table = "mangle"
|
||||
} else {
|
||||
table = "filter"
|
||||
}
|
||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
}
|
||||
@@ -182,6 +203,44 @@ func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
}
|
||||
|
||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||
var src []string
|
||||
if ipsetName != "" {
|
||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||
} else {
|
||||
src = []string{"-s", ip.String()}
|
||||
}
|
||||
specs := []string{
|
||||
"-d", m.wgIface.Address().IP.String(),
|
||||
"-p", protocol,
|
||||
"--dport", port,
|
||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||
}
|
||||
|
||||
specs = append(src, specs...)
|
||||
|
||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: "PREROUTING",
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
@@ -232,6 +291,25 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
@@ -260,9 +338,17 @@ func (m *aclManager) createDefaultChains() error {
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
if chainName == "FORWARD" {
|
||||
// position 2 because we add it after router's, jump rule
|
||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,29 +356,40 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
|
||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
established := getConntrackEstablished()
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
|
||||
m.appendToEntries("PREROUTING",
|
||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
@@ -359,3 +456,18 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,14 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -22,7 +21,7 @@ type Manager struct {
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
router *routerManager
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -44,12 +43,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, iptablesClient, wgIface)
|
||||
m.router, err = newRouterManager(context, iptablesClient)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
@@ -58,10 +57,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -74,62 +73,33 @@ func (m *Manager) AddPeerFiltering(
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources [] netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
return m.aclMgr.DeleteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -155,7 +125,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
@@ -168,7 +138,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddPeerFiltering(
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
@@ -183,7 +153,3 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -11,24 +11,9 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -55,8 +40,23 @@ func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
@@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(
|
||||
rule1, err = manager.AddFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
@@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
@@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
@@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
@@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -5,478 +5,368 @@ package iptables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
type routerManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
}
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
r := &router{
|
||||
m := &routerManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
err := m.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("cleanup routing rules: %s", err)
|
||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
err = r.createContainers()
|
||||
err = m.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("create containers for route: %s", err)
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
return m, err
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var setName string
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: sources,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.GetRuleID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
return rule[i+3]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
// insertRoutingRule inserts an iptables rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
delete(i.rules, ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) Reset() error {
|
||||
var merr *multierror.Error
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
func (i *routerManager) RouteingFwChainName() string {
|
||||
return chainRTFWD
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
err := r.cleanJumpRules()
|
||||
func (i *routerManager) Reset() error {
|
||||
err := i.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules = make(map[string][]string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := tableFilter
|
||||
if chain == chainRTNAT {
|
||||
table = tableNat
|
||||
}
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %v", chain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %v", err)
|
||||
}
|
||||
|
||||
return r.addJumpRules()
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
if chain == chainRTNAT {
|
||||
return tableNat
|
||||
}
|
||||
return tableFilter
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert established rule: %v", err)
|
||||
}
|
||||
|
||||
ruleKey := "established-" + chain
|
||||
r.rules[ruleKey] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTNAT}
|
||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
r.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
rule, found := r.rules[ipv4Nat]
|
||||
if found {
|
||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
}
|
||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
}
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
func (i *routerManager) createContainers() error {
|
||||
if i.rules[Ipv4Forwarding] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
errMSGFormat := "failed creating chain %s,error: %v"
|
||||
err := i.createChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||
}
|
||||
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(p)
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
err = i.createChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *routerManager) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTFWD}
|
||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[Ipv4Forwarding] = rule
|
||||
|
||||
rule = []string{"-j", chainRTNAT}
|
||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *routerManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||
rule, found := i.rules[Ipv4Forwarding]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4Nat]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createChain(table, newChain string) error {
|
||||
chains, err := i.iptablesClient.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = i.iptablesClient.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
ruleType := "forwarding"
|
||||
if table == tableNat {
|
||||
ruleType = "nat"
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
|
||||
@@ -4,13 +4,11 @@ package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -30,7 +28,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
|
||||
defer func() {
|
||||
@@ -39,22 +37,26 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.100.0/24",
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
@@ -63,7 +65,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
@@ -74,7 +76,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
defer func() {
|
||||
@@ -84,13 +86,35 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "income forwarding rule should exist")
|
||||
|
||||
foundRule, found = manager.rules[inForwardRuleKey]
|
||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
@@ -103,8 +127,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
@@ -122,7 +146,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
@@ -132,7 +156,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
@@ -140,14 +164,26 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
@@ -155,14 +191,28 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "income forwarding rule should not exist")
|
||||
|
||||
_, found = manager.rules[inForwardRuleKey]
|
||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
_, found = manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
@@ -171,175 +221,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.Skip("iptables not supported on this system")
|
||||
}
|
||||
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
NatFormat = "netbird-nat-%s"
|
||||
ForwardingFormat = "netbird-fwd-%s"
|
||||
InNatFormat = "netbird-nat-in-%s"
|
||||
InForwardingFormat = "netbird-fwd-in-%s"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
@@ -55,11 +49,11 @@ type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
@@ -70,25 +64,17 @@ type Manager interface {
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair RouterPair) error
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
// RemoveNatRule removes a routing NAT rule
|
||||
RemoveNatRule(pair RouterPair) error
|
||||
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
@@ -97,89 +83,6 @@ type Manager interface {
|
||||
Flush() error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
type LegacyManager interface {
|
||||
RemoveAllLegacyRouteRules() error
|
||||
GetLegacyManagement() bool
|
||||
SetLegacyManagement(bool)
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
oldLegacy := router.GetLegacyManagement()
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
router.SetLegacyManagement(isLegacy)
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
sortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func sortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
return addrCmp < 0
|
||||
}
|
||||
|
||||
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||
})
|
||||
func GenKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func TestGenerateSetName(t *testing.T) {
|
||||
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Result format is correct", func(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("Result format is incorrect: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeIPRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []netip.Prefix{},
|
||||
expected: []netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Two non-overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partially overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6 ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.MergeIPRanges(tt.input)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,18 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID route.ID
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
ID string
|
||||
Source string
|
||||
Destination string
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
func GetInversePair(pair RouterPair) RouterPair {
|
||||
func GetInPair(pair RouterPair) RouterPair {
|
||||
return RouterPair{
|
||||
ID: pair.ID,
|
||||
// invert Source/Destination
|
||||
Source: pair.Destination,
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
Inverse: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,10 +33,9 @@ const (
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
var (
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
@@ -49,6 +48,7 @@ type AclManager struct {
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
chainFwFilter *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@@ -64,7 +64,7 @@ type iFaceMapper interface {
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
@@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
func (m *AclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -120,11 +120,20 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules = append(newRules, ioRule)
|
||||
if !shouldAddToPrerouting(proto, dPort, direction) {
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
|
||||
if err != nil {
|
||||
return newRules, err
|
||||
}
|
||||
newRules = append(newRules, preroutingRule)
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *AclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
@@ -190,7 +199,8 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
|
||||
// input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Payload{
|
||||
@@ -204,13 +214,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
@@ -236,13 +246,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
@@ -256,8 +266,10 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
Exprs: expOut,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
err := m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create default allow rules: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -278,11 +290,15 @@ func (m *AclManager) Flush() error {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
@@ -292,7 +308,18 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
}, nil
|
||||
}
|
||||
|
||||
var expressions []expr.Any
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
@@ -302,15 +329,21 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case firewall.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case firewall.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{protoData},
|
||||
Data: protoData,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -399,9 +432,10 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
} else {
|
||||
chain = m.chainOutputRules
|
||||
}
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
@@ -419,13 +453,139 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case firewall.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case firewall.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
r.nftSet,
|
||||
r.ruleID,
|
||||
ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var ipExpression expr.Any
|
||||
// add individual IP for match if no ipset defined
|
||||
rawIP := ip.To4()
|
||||
if ipset == nil {
|
||||
ipExpression = &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
}
|
||||
} else {
|
||||
ipExpression = &expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipset.Name,
|
||||
SetID: ipset.ID,
|
||||
}
|
||||
}
|
||||
|
||||
expressions := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
ipExpression,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: m.wgIface.Address().IP.To4(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: protoData,
|
||||
},
|
||||
}
|
||||
|
||||
if port != nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*port),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: []byte(ruleId),
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
m.rules[ruleId] = rule
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
return err
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
@@ -441,6 +601,9 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
|
||||
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
|
||||
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
|
||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = m.rConn.Flush()
|
||||
@@ -452,6 +615,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// netbird-acl-output-filter
|
||||
// type filter hook output priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
||||
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
|
||||
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
||||
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
||||
@@ -463,15 +627,24 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
||||
m.addJumpRulesToRtForward() // to
|
||||
m.addMarkAccept()
|
||||
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
|
||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-output-filter
|
||||
// type filter hook output priority filter; policy accept;
|
||||
m.chainPrerouting = m.createPreroutingMangle()
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -494,6 +667,59 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
expressions = []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routeingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addMarkAccept() {
|
||||
// oifname "wt0" meta mark 0x000007e4 accept
|
||||
// iifname "wt0" meta mark 0x000007e4 accept
|
||||
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
|
||||
for _, iface := range ifaces {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: iface, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
@@ -503,9 +729,6 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
@@ -523,6 +746,74 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
|
||||
return m.rConn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: "netbird-acl-prerouting-filter",
|
||||
Table: m.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: m.wgIface.Address().IP.To4(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: postroutingMark,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
@@ -541,9 +832,101 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRuleToInputChain() {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
var srcOp, dstOp expr.CmpOp
|
||||
if netIfName == expr.MetaKeyIIFNAME {
|
||||
srcOp = expr.CmpOpEq
|
||||
dstOp = expr.CmpOpNeq
|
||||
} else {
|
||||
srcOp = expr.CmpOpNeq
|
||||
dstOp = expr.CmpOpEq
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: netIfName, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: srcOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: dstOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
dstOp := expr.CmpOpNeq
|
||||
var srcOp, dstOp expr.CmpOp
|
||||
if iifname == expr.MetaKeyIIFNAME {
|
||||
srcOp = expr.CmpOpNeq
|
||||
dstOp = expr.CmpOpEq
|
||||
} else {
|
||||
srcOp = expr.CmpOpEq
|
||||
dstOp = expr.CmpOpNeq
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: iifname, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -551,6 +934,24 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: srcOp,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -581,6 +982,7 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -588,12 +990,47 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
@@ -695,7 +1132,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(
|
||||
func generateRuleId(
|
||||
ip net.IP,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -718,6 +1155,33 @@ func generatePeerRuleId(
|
||||
}
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
|
||||
// case of icmp port is empty
|
||||
var p string
|
||||
if port != nil {
|
||||
p = port.String()
|
||||
}
|
||||
if ipset != nil {
|
||||
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
|
||||
} else {
|
||||
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil && proto != firewall.ProtocolICMP {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func encodePort(port firewall.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
@@ -727,19 +1191,6 @@ func encodePort(port firewall.Port) []byte {
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
@@ -5,11 +5,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -17,11 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||
tableName = "netbird"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -46,12 +41,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, workTable, wgIface)
|
||||
m.router, err = newRouter(context, workTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,11 +54,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -81,52 +76,33 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
return m.aclManager.DeleteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
return m.router.AddRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveNatRule(pair)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -150,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -181,27 +157,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
oldLegacy := m.router.legacyManagement
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
m.router.legacyManagement = isLegacy
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
@@ -230,16 +185,14 @@ func (m *Manager) Reset() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset forward rules: %v", err)
|
||||
}
|
||||
m.router.ResetForwardRules()
|
||||
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -265,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
@@ -286,7 +239,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
@@ -296,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
@@ -310,33 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conn.InsertRule(rule)
|
||||
}
|
||||
|
||||
@@ -9,30 +9,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -56,9 +40,23 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), ifaceMock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
@@ -72,7 +70,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
@@ -90,34 +88,17 @@ func TestNftablesManager(t *testing.T) {
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
|
||||
expectedExprs1 := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -153,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeletePeerRule(r)
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -165,8 +146,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
@@ -207,9 +187,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
|
||||
431
client/firewall/nftables/route_linux.go
Normal file
431
client/firewall/nftables/route_linux.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRouteingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (r *router) RouteingFwChainName() string {
|
||||
return chainNameRouteingFw
|
||||
}
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
var expression []expr.Any
|
||||
if isNat {
|
||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||
} else {
|
||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||
}
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
ruleType = "nat"
|
||||
}
|
||||
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
var offSet uint32
|
||||
if source {
|
||||
offSet = 12 // src offset
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,798 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.cleanUpDefaultForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.acceptForwardRules()
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
chain := r.chains[chainNameRoutingFw]
|
||||
var exprs []expr.Any
|
||||
|
||||
switch {
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||
case len(sources) == 1:
|
||||
// If there's only one source, we can use it directly
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||
default:
|
||||
// If there are multiple sources, create or get an ipset
|
||||
var err error
|
||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle destination
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
})
|
||||
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
|
||||
var verdict expr.VerdictKind
|
||||
if action == firewall.ActionAccept {
|
||||
verdict = expr.VerdictAccept
|
||||
} else {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
exprs = append(exprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
||||
|
||||
return ruleKey, r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleKey := rule.GetRuleID()
|
||||
nftRule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
return fmt.Errorf("delete: %w", err)
|
||||
}
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range sources {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
|
||||
if err := r.conn.AddSet(set, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
return lookup.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
log.Debugf("removed route rule %s", ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
if pair.Inverse {
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
exprs = append(exprs,
|
||||
&expr.Counter{}, &expr.Masq{},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (r *router) acceptForwardRules() {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 2,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
if source {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
offset := uint32(2) // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
// Handle port range
|
||||
exprs = append(exprs,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Handle single port or multiple ports
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
// Add a bitwise OR operation between port checks
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
@@ -4,15 +4,11 @@ package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -28,50 +24,56 @@ const (
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
defer func(manager *router, pair firewall.RouterPair) {
|
||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
||||
}(manager, testCase.InputPair)
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
@@ -86,20 +88,27 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
)
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
@@ -113,37 +122,45 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
@@ -152,11 +169,20 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
@@ -168,10 +194,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
manager.ResetForwardRules()
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
@@ -179,7 +204,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
}
|
||||
@@ -188,468 +215,6 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
|
||||
defer func(r *router) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
// Verify internal rule content
|
||||
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Check if the rule exists in nftables and verify its content
|
||||
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err, "Failed to get rules from nftables")
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, nftRule, "Rule not found in nftables")
|
||||
t.Log("Actual nftables rule expressions:")
|
||||
for i, expr := range nftRule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
// Verify actual nftables rule content
|
||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IP",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Single Subnet",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||
},
|
||||
{
|
||||
name: "Multiple Subnets with Various Prefix Lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("203.0.113.0/26"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mix of Single IPs and Subnets in Different Positions",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPs/Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add this helper function inside TestNftablesCreateIpSet
|
||||
printNftSets := func() {
|
||||
cmd := exec.Command("nft", "list", "sets")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Failed to run 'nft list sets': %v", err)
|
||||
} else {
|
||||
t.Logf("Current nft sets:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
require.NoError(t, err, "Failed to create IP set")
|
||||
}
|
||||
require.NotNil(t, set, "Created set is nil")
|
||||
|
||||
// Verify set properties
|
||||
assert.Equal(t, setName, set.Name, "Set name mismatch")
|
||||
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
|
||||
assert.True(t, set.Interval, "Set interval property should be true")
|
||||
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
|
||||
|
||||
// Fetch the created set from nftables
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
require.NotNil(t, fetchedSet, "Fetched set is nil")
|
||||
|
||||
// Verify set elements
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
// Count the number of unique prefixes (excluding interval end markers)
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check against expected merged prefixes
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
|
||||
|
||||
// Verify each expected prefix is in the set
|
||||
for _, expected := range tt.expected {
|
||||
found := false
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
if expected.Contains(ip) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected prefix %s not found in set", expected)
|
||||
}
|
||||
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
t.Logf("Failed to delete set: %v", err)
|
||||
printNftSets()
|
||||
}
|
||||
require.NoError(t, err, "Failed to delete set")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotNil(t, rule, "Rule should not be nil")
|
||||
|
||||
// Verify sources and destination
|
||||
if expectSet {
|
||||
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
|
||||
} else if len(sources) == 1 && sources[0].Bits() != 0 {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
|
||||
}
|
||||
}
|
||||
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
|
||||
}
|
||||
|
||||
// Verify protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
|
||||
}
|
||||
|
||||
// Verify ports
|
||||
if sPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
|
||||
}
|
||||
if dPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
|
||||
}
|
||||
|
||||
// Verify action
|
||||
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
|
||||
}
|
||||
|
||||
func containsSetLookup(exprs []expr.Any) bool {
|
||||
for _, e := range exprs {
|
||||
if _, ok := e.(*expr.Lookup); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
|
||||
var offset uint32
|
||||
if isSource {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
var payloadFound, bitwiseFound, cmpFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Bitwise:
|
||||
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
|
||||
bitwiseFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
|
||||
}
|
||||
|
||||
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
var offset uint32 = 2 // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
var payloadFound, portMatchFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if port.IsRange {
|
||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||
portMatchFound = true
|
||||
}
|
||||
} else {
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||
for _, p := range port.Values {
|
||||
if uint16(p) == portValue {
|
||||
portMatchFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if payloadFound && portMatchFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
if ex.Key == expr.MetaKeyL4PROTO {
|
||||
metaFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return metaFound && cmpFound
|
||||
}
|
||||
|
||||
func containsAction(exprs []expr.Any, action firewall.Action) bool {
|
||||
for _, e := range exprs {
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
return verdict.Kind == expr.VerdictAccept
|
||||
case firewall.ActionDrop:
|
||||
return verdict.Kind == expr.VerdictDrop
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() int {
|
||||
nf := nftables.Conn{}
|
||||
@@ -685,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
@@ -708,7 +273,7 @@ func deleteWorkTable() {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//go:build !android
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
var (
|
||||
InsertRuleTestCases = []struct {
|
||||
@@ -15,8 +13,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -24,8 +22,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +38,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -12,8 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
@@ -24,7 +22,7 @@ var (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
@@ -105,26 +103,26 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -190,22 +188,8 @@ func (m *Manager) AddPeerFiltering(
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -231,11 +215,6 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
@@ -358,6 +337,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop, true
|
||||
}
|
||||
@@ -416,7 +396,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -424,7 +404,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,16 +11,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
if i.SetFilterFunc == nil {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -36,7 +35,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -50,10 +49,10 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
func TestManagerAddFiltering(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
SetFilterFunc: func(iface.PacketFilter) error {
|
||||
isSetFilterCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -72,7 +71,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -91,7 +90,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
|
||||
func TestManagerDeleteRule(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -107,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -120,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeletePeerRule(r)
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
@@ -141,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeletePeerRule(r)
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
@@ -237,7 +236,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -253,7 +252,7 @@ func TestManagerReset(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -272,7 +271,7 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -291,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -340,7 +339,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
func TestRemovePacketHook(t *testing.T) {
|
||||
// creating mock iface
|
||||
iface := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
@@ -389,7 +388,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
@@ -407,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package configurer
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrPeerNotFound = errors.New("peer not found")
|
||||
@@ -1,9 +0,0 @@
|
||||
package configurer
|
||||
|
||||
import "time"
|
||||
|
||||
type WGStats struct {
|
||||
LastHandshake time.Time
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
type WGConfigurer interface {
|
||||
ConfigureInterface(privateKey string, port int) error
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close()
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func wgLogLevel() int {
|
||||
if os.Getenv("NB_WG_DEBUG") == "true" {
|
||||
return device.LogLevelVerbose
|
||||
} else {
|
||||
return device.LogLevelSilent
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package device
|
||||
|
||||
// CustomWindowsGUIDString is a custom GUID string for the interface
|
||||
var CustomWindowsGUIDString string
|
||||
@@ -1,16 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
|
||||
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// this function is different on Android
|
||||
func (w *WGIface) Create() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 20 * time.Millisecond,
|
||||
MaxElapsedTime: 500 * time.Millisecond,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
operation := func() error {
|
||||
cfgr, err := w.tun.Create()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.configurer = cfgr
|
||||
return nil
|
||||
}
|
||||
|
||||
return backoff.Retry(operation, backOff)
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func (w *WGIface) Destroy() error {
|
||||
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (w *WGIface) Destroy() error {
|
||||
link, err := netlink.LinkByName(w.Name())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkDel(link); err != nil {
|
||||
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build android || (ios && !darwin)
|
||||
|
||||
package iface
|
||||
|
||||
import "errors"
|
||||
|
||||
func (w *WGIface) Destroy() error {
|
||||
return errors.New("not supported on mobile")
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (w *WGIface) Destroy() error {
|
||||
netshCmd := GetSystem32Command("netsh")
|
||||
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
|
||||
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
|
||||
func GetSystem32Command(command string) string {
|
||||
_, err := exec.LookPath(command)
|
||||
if err == nil {
|
||||
return command
|
||||
}
|
||||
|
||||
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
|
||||
|
||||
return "C:\\windows\\system32\\" + command + ".exe"
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
|
||||
func (r RuleID) GetRuleID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto manager.Protocol,
|
||||
sPort *manager.Port,
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
@@ -25,18 +23,16 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
mutex sync.Mutex
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||
routeRules: make(map[id.RuleID]struct{}),
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +46,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
for _, pairs := range d.peerRulesPairs {
|
||||
for _, pairs := range d.rulesPairs {
|
||||
total += len(pairs)
|
||||
}
|
||||
log.Infof(
|
||||
@@ -63,34 +59,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
return
|
||||
}
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
defer func() {
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
enableSSH := (networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled)
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
|
||||
@@ -100,9 +83,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
@@ -114,20 +97,20 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules = append(rules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
@@ -147,97 +130,29 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
break
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
d.rulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
for pairID, rules := range d.rulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
delete(d.rulesPairs, pairID)
|
||||
}
|
||||
}
|
||||
d.peerRulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||
var newRouteRules = make(map[id.RuleID]struct{})
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply route ACL: %w", err)
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
}
|
||||
|
||||
for id := range d.routeRules {
|
||||
if _, ok := newRouteRules[id]; !ok {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(d.routeRules, id)
|
||||
}
|
||||
}
|
||||
d.routeRules = newRouteRules
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", fmt.Errorf("source ranges is empty")
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
for _, sourceRange := range rule.SourceRanges {
|
||||
source, err := netip.ParsePrefix(sourceRange)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse source range: %w", err)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
}
|
||||
|
||||
var destination netip.Prefix
|
||||
if rule.IsDynamic {
|
||||
destination = getDefault(sources[0])
|
||||
} else {
|
||||
var err error
|
||||
destination, err = netip.ParsePrefix(rule.Destination)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse destination: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(rule.Action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid action: %w", err)
|
||||
}
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.GetRuleID()), nil
|
||||
d.rulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
@@ -264,16 +179,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
@@ -295,7 +210,7 @@ func (d *DefaultManager) addInRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -306,7 +221,7 @@ func (d *DefaultManager) addInRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -324,7 +239,7 @@ func (d *DefaultManager) addOutRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -335,7 +250,7 @@ func (d *DefaultManager) addOutRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
@@ -344,21 +259,21 @@ func (d *DefaultManager) addOutRules(
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getRuleID(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) id.RuleID {
|
||||
) string {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
@@ -368,7 +283,7 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
// but other has port definitions or has drop policy.
|
||||
func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
@@ -376,14 +291,14 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
}
|
||||
|
||||
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
||||
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
|
||||
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
@@ -393,7 +308,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
|
||||
if drop {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
return
|
||||
@@ -421,7 +336,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
// calculate squash for different directions
|
||||
if r.Direction == mgmProto.RuleDirection_IN {
|
||||
if r.Direction == mgmProto.FirewallRule_IN {
|
||||
addRuleToCalculationMap(i, r, in)
|
||||
} else {
|
||||
addRuleToCalculationMap(i, r, out)
|
||||
@@ -430,14 +345,14 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.RuleProtocol{
|
||||
mgmProto.RuleProtocol_ALL,
|
||||
mgmProto.RuleProtocol_ICMP,
|
||||
mgmProto.RuleProtocol_TCP,
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
||||
mgmProto.FirewallRule_ALL,
|
||||
mgmProto.FirewallRule_ICMP,
|
||||
mgmProto.FirewallRule_TCP,
|
||||
mgmProto.FirewallRule_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
// don't squash if :
|
||||
@@ -450,12 +365,12 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: protocol,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
if protocol == mgmProto.RuleProtocol_ALL {
|
||||
if protocol == mgmProto.FirewallRule_ALL {
|
||||
// if we have ALL traffic type squashed rule
|
||||
// it allows all other type of traffic, so we can stop processing
|
||||
break
|
||||
@@ -463,11 +378,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
}
|
||||
|
||||
squash(in, mgmProto.RuleDirection_IN)
|
||||
squash(out, mgmProto.RuleDirection_OUT)
|
||||
squash(in, mgmProto.FirewallRule_IN)
|
||||
squash(out, mgmProto.FirewallRule_OUT)
|
||||
|
||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
||||
return squashedRules, squashedProtocols
|
||||
}
|
||||
|
||||
@@ -497,26 +412,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
return firewall.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
case mgmProto.FirewallRule_UDP:
|
||||
return firewall.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
case mgmProto.FirewallRule_ICMP:
|
||||
return firewall.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
case mgmProto.FirewallRule_ALL:
|
||||
return firewall.ProtocolALL, nil
|
||||
default:
|
||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
@@ -527,41 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||
}
|
||||
|
||||
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
||||
switch action {
|
||||
case mgmProto.RuleAction_ACCEPT:
|
||||
case mgmProto.FirewallRule_ACCEPT:
|
||||
return firewall.ActionAccept, nil
|
||||
case mgmProto.RuleAction_DROP:
|
||||
case mgmProto.FirewallRule_DROP:
|
||||
return firewall.ActionDrop, nil
|
||||
default:
|
||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
}
|
||||
}
|
||||
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
if portInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewall.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewall.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
Port: "53",
|
||||
},
|
||||
},
|
||||
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.GetRuleID()] = struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_ICMP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_DROP,
|
||||
Protocol: mgmProto.FirewallRule_ICMP,
|
||||
},
|
||||
)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
||||
for id := range acl.rulesPairs {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = true
|
||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_IN:
|
||||
case r.Direction != mgmProto.FirewallRule_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
if len(acl.rulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
iface "github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||
@@ -78,7 +77,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
|
||||
}
|
||||
|
||||
// SetFilter mocks base method.
|
||||
func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error {
|
||||
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetFilter", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
||||
@@ -3,7 +3,6 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -181,7 +180,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
|
||||
continue
|
||||
}
|
||||
|
||||
return TokenInfo{}, errors.New(tokenResponse.ErrorDescription)
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
|
||||
@@ -69,11 +69,6 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
@@ -86,7 +81,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -144,18 +143,6 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
cert := p.providerConfig.ClientCertPair
|
||||
if cert != nil {
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
},
|
||||
}
|
||||
sslClient := &http.Client{Transport: tr}
|
||||
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
token, err := p.handleRequest(req)
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
|
||||
@@ -2,7 +2,6 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -16,9 +15,9 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -58,8 +57,6 @@ type ConfigInput struct {
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -105,23 +102,11 @@ type Config struct {
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
//Path to a certificate used for mTLS authentication
|
||||
ClientCertPath string
|
||||
|
||||
//Path to corresponding private key of ClientCertPath
|
||||
ClientCertKeyPath string
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
if configFileIsExists(configPath) {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
@@ -164,17 +149,13 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
@@ -404,26 +385,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertPath != "" {
|
||||
config.ClientCertPath = input.ClientCertPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
|
||||
if err != nil {
|
||||
log.Error("Failed to load mTLS cert/key pair: ", err)
|
||||
} else {
|
||||
config.ClientCertKeyPair = &cert
|
||||
log.Info("Loaded client mTLS cert/key pair")
|
||||
}
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -17,18 +17,15 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -58,20 +55,22 @@ func NewConnectClient(
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run() error {
|
||||
return c.run(MobileDependency{}, nil, nil)
|
||||
return c.run(MobileDependency{}, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
tunAdapter device.TunAdapter,
|
||||
tunAdapter iface.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []string,
|
||||
@@ -85,7 +84,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -101,13 +100,15 @@ func (c *ConnectClient) RunOniOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(
|
||||
mobileDependency MobileDependency,
|
||||
probes *ProbeHolder,
|
||||
runningChan chan error,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -159,11 +160,12 @@ func (c *ConnectClient) run(
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
runningChanOpen := true
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if c.isContextCancelled() {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
state.Set(StatusConnecting)
|
||||
@@ -185,7 +187,8 @@ func (c *ConnectClient) run(
|
||||
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
defer func() {
|
||||
if err = mgmClient.Close(); err != nil {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -196,7 +199,6 @@ func (c *ConnectClient) run(
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -206,9 +208,10 @@ func (c *ConnectClient) run(
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
signalURL := fmt.Sprintf("%s://%s",
|
||||
@@ -241,23 +244,6 @@ func (c *ConnectClient) run(
|
||||
|
||||
c.statusRecorder.MarkSignalConnected()
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
||||
if len(relayURLs) > 0 {
|
||||
if token != nil {
|
||||
if err := relayManager.UpdateToken(token); err != nil {
|
||||
log.Errorf("failed to update token: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
}
|
||||
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
||||
if err = relayManager.Serve(); err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
}
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
@@ -269,11 +255,11 @@ func (c *ConnectClient) run(
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := c.engine.Start(); err != nil {
|
||||
err = c.engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
@@ -281,26 +267,17 @@ func (c *ConnectClient) run(
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil && runningChanOpen {
|
||||
runningChan <- nil
|
||||
close(runningChan)
|
||||
runningChanOpen = false
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
err = c.engine.Stop()
|
||||
if err != nil {
|
||||
log.Errorf("failed stopping engine %v", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
log.Info("stopped NetBird client")
|
||||
|
||||
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
|
||||
@@ -316,31 +293,13 @@ func (c *ConnectClient) run(
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
|
||||
if relayCfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
token := &hmac.Token{
|
||||
Payload: relayCfg.GetTokenPayload(),
|
||||
Signature: relayCfg.GetTokenSignature(),
|
||||
}
|
||||
|
||||
return relayCfg.GetUrls(), token
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
var e *Engine
|
||||
c.engineMutex.Lock()
|
||||
e = c.engine
|
||||
@@ -348,28 +307,6 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
return e
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.engineMutex.Lock()
|
||||
defer c.engineMutex.Unlock()
|
||||
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
return c.engine.Stop()
|
||||
}
|
||||
|
||||
func (c *ConnectClient) isContextCancelled() bool {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
@@ -460,43 +397,19 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
return notifier
|
||||
}
|
||||
|
||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||
func freePort(initPort int) (int, error) {
|
||||
func freePort(start int) (int, error) {
|
||||
addr := net.UDPAddr{}
|
||||
if initPort == 0 {
|
||||
initPort = iface.DefaultWgPort
|
||||
if start == 0 {
|
||||
start = iface.DefaultWgPort
|
||||
}
|
||||
|
||||
addr.Port = initPort
|
||||
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err == nil {
|
||||
closeConnWithLog(conn)
|
||||
return initPort, nil
|
||||
}
|
||||
|
||||
// if the port is already in use, ask the system for a free port
|
||||
addr.Port = 0
|
||||
conn, err = net.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to get a free port: %v", err)
|
||||
}
|
||||
|
||||
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, errors.New("wrong address type when getting a free port")
|
||||
}
|
||||
closeConnWithLog(conn)
|
||||
return udpAddr.Port, nil
|
||||
}
|
||||
|
||||
func closeConnWithLog(conn *net.UDPConn) {
|
||||
startClosing := time.Now()
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err)
|
||||
}
|
||||
if time.Since(startClosing) > time.Second {
|
||||
log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing))
|
||||
for x := start; x <= 65535; x++ {
|
||||
addr.Port = x
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Close()
|
||||
return x, nil
|
||||
}
|
||||
return 0, errors.New("no free ports")
|
||||
}
|
||||
|
||||
@@ -7,55 +7,51 @@ import (
|
||||
|
||||
func Test_freePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
want int
|
||||
shouldMatch bool
|
||||
name string
|
||||
port int
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "not provided, fallback to default",
|
||||
port: 0,
|
||||
want: 51820,
|
||||
shouldMatch: true,
|
||||
name: "available",
|
||||
port: 51820,
|
||||
want: 51820,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "provided and available",
|
||||
port: 51821,
|
||||
want: 51821,
|
||||
shouldMatch: true,
|
||||
name: "notavailable",
|
||||
port: 51830,
|
||||
want: 51831,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "provided and not available",
|
||||
port: 51830,
|
||||
want: 51830,
|
||||
shouldMatch: false,
|
||||
name: "noports",
|
||||
port: 65535,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
defer func(c1 *net.UDPConn) {
|
||||
_ = c1.Close()
|
||||
}(c1)
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := freePort(tt.port)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("got an error while getting free port: %v", err)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.shouldMatch && got != tt.want {
|
||||
t.Errorf("got a different port %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
if !tt.shouldMatch && got == tt.want {
|
||||
t.Errorf("got the same port %v, want a different port", tt.want)
|
||||
if got != tt.want {
|
||||
t.Errorf("freePort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,12 +15,6 @@ type hostManager interface {
|
||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
Domains []string
|
||||
ServerIP string
|
||||
ServerPort int
|
||||
}
|
||||
|
||||
type HostDNSConfig struct {
|
||||
Domains []DomainConfig `json:"domains"`
|
||||
RouteAll bool `json:"routeAll"`
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
|
||||
keyServerAddresses = "ServerAddresses"
|
||||
@@ -29,12 +28,12 @@ const (
|
||||
scutilPath = "/usr/sbin/scutil"
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
// primaryServiceID primary interface in the system. AKA the interface with the default route
|
||||
primaryServiceID string
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager() (hostManager, error) {
|
||||
@@ -50,6 +49,20 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
|
||||
if config.RouteAll {
|
||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns setup for all: %w", err)
|
||||
}
|
||||
} else if s.primaryServiceID != "" {
|
||||
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("remote key from system config: %w", err)
|
||||
}
|
||||
s.primaryServiceID = ""
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
@@ -60,19 +73,6 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
err = s.recordSystemDNSSettings(true)
|
||||
if err != nil {
|
||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
searchDomains = append(searchDomains, "\"\"")
|
||||
err = s.addLocalDNS()
|
||||
if err != nil {
|
||||
log.Infof("failed to enable split DNS")
|
||||
}
|
||||
}
|
||||
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
@@ -110,17 +110,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreHostDNS() error {
|
||||
keys := s.getRemovableKeysWithDefaults()
|
||||
for _, key := range keys {
|
||||
lines := ""
|
||||
for key := range s.createdKeys {
|
||||
lines += buildRemoveKeyOperation(key)
|
||||
keyType := "search"
|
||||
if strings.Contains(key, matchSuffix) {
|
||||
keyType = "match"
|
||||
}
|
||||
log.Infof("removing %s domains from system", keyType)
|
||||
err := s.removeKeyFromSystemConfig(key)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove %s domains from system: %s", keyType, err)
|
||||
}
|
||||
}
|
||||
if s.primaryServiceID != "" {
|
||||
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
log.Infof("restoring DNS resolver configuration for system")
|
||||
}
|
||||
_, err := runSystemConfigCommand(wrapCommand(lines))
|
||||
if err != nil {
|
||||
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
||||
return fmt.Errorf("clean system: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
@@ -130,19 +136,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
for key := range s.createdKeys {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@@ -155,97 +148,6 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addLocalDNS() error {
|
||||
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
|
||||
err := s.recordSystemDNSSettings(true)
|
||||
log.Errorf("Unable to get system DNS configuration")
|
||||
return err
|
||||
}
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
|
||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Info("Not enabling local DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
|
||||
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
|
||||
return nil
|
||||
}
|
||||
|
||||
systemDNSSettings, err := s.getSystemDNSSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get current DNS config: %w", err)
|
||||
}
|
||||
s.systemDNSSettings = systemDNSSettings
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
primaryServiceKey, _, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
|
||||
line := buildCommandLine("show", dnsServiceKey, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
switch {
|
||||
case strings.HasPrefix(line, "DomainName :"):
|
||||
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
|
||||
case line == "SearchDomains : <array> {":
|
||||
inSearchDomainsArray = true
|
||||
continue
|
||||
case line == "ServerAddresses : <array> {":
|
||||
inServerAddressesArray = true
|
||||
continue
|
||||
case line == "}":
|
||||
inSearchDomainsArray = false
|
||||
inServerAddressesArray = false
|
||||
}
|
||||
|
||||
if inSearchDomainsArray {
|
||||
searchDomain := strings.Split(line, " : ")[1]
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
|
||||
dnsSettings.ServerIP = address
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return dnsSettings, err
|
||||
}
|
||||
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = 53
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
@@ -292,6 +194,23 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
|
||||
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
||||
s.primaryServiceID = primaryServiceKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
@@ -320,6 +239,19 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying dns setup, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
func TestResponseWriterLocalAddr(t *testing.T) {
|
||||
|
||||
@@ -94,7 +94,7 @@ func NewDefaultServer(
|
||||
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(wgInterface)
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -15,18 +15,16 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
filter iface.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
@@ -45,11 +43,11 @@ func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() device.PacketFilter {
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetDevice() *device.FilteredDevice {
|
||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -61,13 +59,13 @@ func (w *mocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
w.filter = filter
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
||||
return configurer.WGStats{}, nil
|
||||
func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
|
||||
return iface.WGStats{}, nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
@@ -536,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
|
||||
@@ -128,9 +128,6 @@ func (s *serviceViaListener) RuntimeIP() string {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ServiceViaMemory struct {
|
||||
type serviceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
@@ -22,8 +22,8 @@ type ServiceViaMemory struct {
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
s := &ServiceViaMemory{
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@@ -52,7 +52,7 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@@ -67,23 +67,23 @@ func (s *ServiceViaMemory) Stop() {
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RuntimePort() int {
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RuntimeIP() string {
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
|
||||
@@ -24,7 +24,7 @@ const (
|
||||
probeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
const testRecord = "."
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
@@ -42,7 +42,6 @@ type upstreamResolverBase struct {
|
||||
upstreamServers []string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
@@ -125,7 +124,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
u.successCount.Add(1)
|
||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
@@ -174,11 +172,6 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
default:
|
||||
}
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
|
||||
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
|
||||
return
|
||||
}
|
||||
|
||||
var success bool
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
@@ -190,7 +183,7 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
||||
err := u.testNameserver(upstream)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, err)
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
@@ -231,7 +224,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
||||
if err := u.testNameserver(upstream); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
@@ -251,7 +244,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
||||
u.failsCount.Store(0)
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
}
|
||||
@@ -273,14 +265,13 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
func (u *upstreamResolverBase) testNameserver(server string) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
@@ -4,7 +4,6 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -18,9 +17,9 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
interfaceName string
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
iIndex int
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
@@ -33,11 +32,17 @@ func newUpstreamResolver(
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
interfaceName: interfaceName,
|
||||
iIndex: index,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
@@ -48,7 +53,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
client := &dns.Client{}
|
||||
upstreamHost, _, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
log.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
@@ -60,35 +65,26 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
client = u.getClientPrivate(timeout)
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return client.Exchange(r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: ip,
|
||||
IP: u.lIP,
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
|
||||
}
|
||||
|
||||
if err := c.Control(fn); err != nil {
|
||||
@@ -105,7 +101,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
||||
client := &dns.Client{
|
||||
Dialer: dialer,
|
||||
}
|
||||
return client, nil
|
||||
return client
|
||||
}
|
||||
|
||||
func getInterfaceIndex(interfaceName string) (int, error) {
|
||||
|
||||
@@ -5,9 +5,7 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
@@ -16,7 +14,7 @@ type WGIface interface {
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetStats(peerKey string) (iface.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetStats(peerKey string) (iface.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
@@ -23,12 +22,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
@@ -39,11 +34,11 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
@@ -106,8 +101,7 @@ type EngineConfig struct {
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
type Engine struct {
|
||||
// signal is a Signal Service client
|
||||
signal signal.Client
|
||||
signaler *peer.Signaler
|
||||
signal signal.Client
|
||||
// mgmClient is a Management Service client
|
||||
mgmClient mgm.Client
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
@@ -128,8 +122,7 @@ type Engine struct {
|
||||
// STUNs is a list of STUN servers used by ICE
|
||||
STUNs []*stun.URI
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
TURNs []*stun.URI
|
||||
stunTurn atomic.Value
|
||||
TURNs []*stun.URI
|
||||
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
@@ -141,7 +134,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface *iface.WGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
@@ -162,12 +155,15 @@ type Engine struct {
|
||||
|
||||
dnsServer dns.Server
|
||||
|
||||
probes *ProbeHolder
|
||||
mgmProbe *Probe
|
||||
signalProbe *Probe
|
||||
relayProbe *Probe
|
||||
wgProbe *Probe
|
||||
|
||||
wgConnWorker sync.WaitGroup
|
||||
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -182,7 +178,6 @@ func NewEngine(
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
@@ -193,11 +188,13 @@ func NewEngine(
|
||||
clientCancel,
|
||||
signalClient,
|
||||
mgmClient,
|
||||
relayManager,
|
||||
config,
|
||||
mobileDep,
|
||||
statusRecorder,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
checks,
|
||||
)
|
||||
}
|
||||
@@ -208,20 +205,21 @@ func NewEngineWithProbes(
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
probes *ProbeHolder,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
|
||||
return &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||
mgmClient: mgmClient,
|
||||
relayManager: relayManager,
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
@@ -231,51 +229,43 @@ func NewEngineWithProbes(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
probes: probes,
|
||||
mgmProbe: mgmProbe,
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
wgProbe: wgProbe,
|
||||
checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) Stop() error {
|
||||
if e == nil {
|
||||
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// stopping network monitor first to avoid starting the engine again
|
||||
if e.networkMonitor != nil {
|
||||
e.networkMonitor.Stop()
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = nil
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
e.wgConnWorker.Wait()
|
||||
log.Infof("stopped Netbird Engine")
|
||||
return nil
|
||||
}
|
||||
@@ -300,7 +290,7 @@ func (e *Engine) Start() error {
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -326,7 +316,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@@ -475,45 +465,80 @@ func (e *Engine) removePeer(peerKey string) error {
|
||||
conn, exists := e.peerConns[peerKey]
|
||||
if exists {
|
||||
delete(e.peerConns, peerKey)
|
||||
conn.Close()
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case *peer.ConnectionAlreadyClosedError:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||
err := s.Send(&sProto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey.String(),
|
||||
Body: &sProto.Body{
|
||||
Type: sProto.Body_CANDIDATE,
|
||||
Payload: candidate.Marshal(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendSignal(message *sProto.Message, s signal.Client) error {
|
||||
return s.Send(message)
|
||||
}
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client,
|
||||
isAnswer bool) error {
|
||||
var t sProto.Body_Type
|
||||
if isAnswer {
|
||||
t = sProto.Body_ANSWER
|
||||
} else {
|
||||
t = sProto.Body_OFFER
|
||||
}
|
||||
|
||||
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.Send(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if update.GetWiretrusteeConfig() != nil {
|
||||
wCfg := update.GetWiretrusteeConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
relayMsg := wCfg.GetRelay()
|
||||
if relayMsg != nil {
|
||||
c := &auth.Token{
|
||||
Payload: relayMsg.GetTokenPayload(),
|
||||
Signature: relayMsg.GetTokenSignature(),
|
||||
}
|
||||
if err := e.relayManager.UpdateToken(c); err != nil {
|
||||
log.Errorf("failed to update relay token: %v", err)
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@@ -627,7 +652,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||
IP: e.config.WgAddr,
|
||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||
FQDN: conf.GetFqdn(),
|
||||
})
|
||||
|
||||
@@ -712,11 +737,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply ACLs in the beginning to avoid security leaks
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
@@ -783,6 +803,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
@@ -910,13 +934,68 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
conn.Open()
|
||||
e.wgConnWorker.Add(1)
|
||||
go e.connWorker(conn, peerKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
defer e.wgConnWorker.Done()
|
||||
for {
|
||||
|
||||
// randomize starting time a bit
|
||||
min := 500
|
||||
max := 2000
|
||||
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
return
|
||||
case <-time.After(duration):
|
||||
}
|
||||
|
||||
// if peer has been removed -> give up
|
||||
if !e.peerExists(peerKey) {
|
||||
log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey)
|
||||
return
|
||||
}
|
||||
|
||||
if !e.signal.Ready() {
|
||||
log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
|
||||
continue
|
||||
}
|
||||
|
||||
// we might have received new STUN and TURN servers meanwhile, so update them
|
||||
e.syncMsgMux.Lock()
|
||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
err := conn.Open(e.ctx)
|
||||
if err != nil {
|
||||
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
||||
var connectionClosedError *peer.ConnectionClosedError
|
||||
switch {
|
||||
case errors.As(err, &connectionClosedError):
|
||||
// conn has been forced to close, so we exit the loop
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) peerExists(peerKey string) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
_, ok := e.peerConns[peerKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
|
||||
log.Debugf("creating peer connection %s", pubKey)
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
|
||||
wgConfig := peer.WgConfig{
|
||||
RemoteKey: pubKey,
|
||||
@@ -949,29 +1028,52 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
ICEConfig: peer.ICEConfig{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
StunTurn: stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
Timeout: timeout,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgPubKey, err := wgtypes.ParseKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signalOffer := func(offerAnswer peer.OfferAnswer) error {
|
||||
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
|
||||
}
|
||||
|
||||
signalCandidate := func(candidate ice.Candidate) error {
|
||||
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
|
||||
}
|
||||
|
||||
signalAnswer := func(offerAnswer peer.OfferAnswer) error {
|
||||
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
|
||||
}
|
||||
|
||||
peerConn.SetSignalCandidate(signalCandidate)
|
||||
peerConn.SetSignalOffer(signalOffer)
|
||||
peerConn.SetSignalAnswer(signalAnswer)
|
||||
peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
|
||||
return sendSignal(message, e.signal)
|
||||
})
|
||||
|
||||
if e.rpManager != nil {
|
||||
|
||||
peerConn.SetOnConnected(e.rpManager.OnConnected)
|
||||
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
|
||||
}
|
||||
@@ -1014,7 +1116,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
||||
})
|
||||
case sProto.Body_ANSWER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
@@ -1037,7 +1138,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
||||
})
|
||||
case sProto.Body_CANDIDATE:
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
@@ -1046,7 +1146,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
||||
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
}
|
||||
|
||||
@@ -1123,12 +1223,21 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
}
|
||||
e.wgInterface = nil
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
@@ -1167,15 +1276,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
var mArgs *device.MobileIFaceArguments
|
||||
var mArgs *iface.MobileIFaceArguments
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
mArgs = &iface.MobileIFaceArguments{
|
||||
TunAdapter: e.mobileDep.TunAdapter,
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
case "ios":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
mArgs = &iface.MobileIFaceArguments{
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
default:
|
||||
@@ -1291,27 +1400,24 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
}
|
||||
|
||||
func (e *Engine) receiveProbeEvents() {
|
||||
if e.probes == nil {
|
||||
return
|
||||
}
|
||||
if e.probes.SignalProbe != nil {
|
||||
go e.probes.SignalProbe.Receive(e.ctx, func() bool {
|
||||
if e.signalProbe != nil {
|
||||
go e.signalProbe.Receive(e.ctx, func() bool {
|
||||
healthy := e.signal.IsHealthy()
|
||||
log.Debugf("received signal probe request, healthy: %t", healthy)
|
||||
return healthy
|
||||
})
|
||||
}
|
||||
|
||||
if e.probes.MgmProbe != nil {
|
||||
go e.probes.MgmProbe.Receive(e.ctx, func() bool {
|
||||
if e.mgmProbe != nil {
|
||||
go e.mgmProbe.Receive(e.ctx, func() bool {
|
||||
healthy := e.mgmClient.IsHealthy()
|
||||
log.Debugf("received management probe request, healthy: %t", healthy)
|
||||
return healthy
|
||||
})
|
||||
}
|
||||
|
||||
if e.probes.RelayProbe != nil {
|
||||
go e.probes.RelayProbe.Receive(e.ctx, func() bool {
|
||||
if e.relayProbe != nil {
|
||||
go e.relayProbe.Receive(e.ctx, func() bool {
|
||||
healthy := true
|
||||
|
||||
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
||||
@@ -1330,13 +1436,13 @@ func (e *Engine) receiveProbeEvents() {
|
||||
})
|
||||
}
|
||||
|
||||
if e.probes.WgProbe != nil {
|
||||
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
||||
if e.wgProbe != nil {
|
||||
go e.wgProbe.Receive(e.ctx, func() bool {
|
||||
log.Debug("received wg probe request")
|
||||
|
||||
for _, peer := range e.peerConns {
|
||||
key := peer.GetKey()
|
||||
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
|
||||
wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
@@ -1360,16 +1466,12 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client, engine will be recreated")
|
||||
e.clientCancel()
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
@@ -1391,17 +1493,16 @@ func (e *Engine) startNetworkMonitor() {
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
@@ -1426,21 +1527,6 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
}
|
||||
e.dnsServer.Stop()
|
||||
e.dnsServer = nil
|
||||
err := fmt.Errorf("DNS server stopped")
|
||||
nsGroupStates := e.statusRecorder.GetDNSStates()
|
||||
for i := range nsGroupStates {
|
||||
nsGroupStates[i].Enabled = false
|
||||
nsGroupStates[i].Error = err
|
||||
}
|
||||
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -24,21 +24,18 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
@@ -60,12 +57,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
@@ -81,23 +72,13 @@ func TestEngine_SSH(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -226,29 +207,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
|
||||
wgIface := &iface.MockWGIface{
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@@ -430,8 +403,8 @@ func TestEngine_Sync(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
@@ -590,8 +563,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
@@ -761,8 +733,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
@@ -823,6 +794,20 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
@@ -832,7 +817,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -859,8 +844,6 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
@@ -1026,8 +1009,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
@@ -1042,7 +1024,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
@@ -1055,17 +1037,12 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Relay: &server.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &server.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
@@ -1080,7 +1057,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -1092,17 +1069,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
// Android only
|
||||
TunAdapter device.TunAdapter
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
NetworkChangeListener listener.NetworkChangeListener
|
||||
HostDNSAddresses []string
|
||||
|
||||
@@ -4,7 +4,6 @@ package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -22,17 +21,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||
if err := unix.Close(fd); err != nil {
|
||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -44,13 +34,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -61,11 +49,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
if !route.Dst.Addr().IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -3,73 +3,243 @@ package networkmonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
const (
|
||||
unreachable = 0
|
||||
incomplete = 1
|
||||
probe = 2
|
||||
delay = 3
|
||||
stale = 4
|
||||
reachable = 5
|
||||
permanent = 6
|
||||
tbd = 7
|
||||
)
|
||||
|
||||
const interval = 10 * time.Second
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route monitor: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := routeMonitor.Stop(); err != nil {
|
||||
log.Errorf("Network monitor: failed to stop route monitor: %v", err)
|
||||
var neighborv4, neighborv6 *systemops.Neighbor
|
||||
{
|
||||
initialNeighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
|
||||
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
|
||||
}
|
||||
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
case route := <-routeMonitor.RouteUpdates():
|
||||
if route.Destination.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if routeChanged(route, nexthopv4, nexthopv6, callback) {
|
||||
break
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
if isSoftInterface(intf) {
|
||||
log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
|
||||
return false
|
||||
}
|
||||
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
|
||||
if n, ok := initialNeighbors[nexthop.IP]; ok &&
|
||||
n.State != unreachable &&
|
||||
n.State != incomplete &&
|
||||
n.State != tbd {
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
nexthopv6 systemops.Nexthop,
|
||||
neighborv6 *systemops.Neighbor,
|
||||
) bool {
|
||||
neighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
log.Errorf("network monitor: error fetching current neighbors: %v", err)
|
||||
return false
|
||||
}
|
||||
if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
|
||||
return true
|
||||
}
|
||||
|
||||
switch route.Type {
|
||||
case systemops.RouteModified:
|
||||
// TODO: get routing table to figure out if our route is affected for modified routes
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
routes, err := getRoutes()
|
||||
if err != nil {
|
||||
log.Errorf("network monitor: error fetching current routes: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
|
||||
return true
|
||||
case systemops.RouteAdded:
|
||||
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
case systemops.RouteDeleted:
|
||||
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isSoftInterface(name string) bool {
|
||||
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
|
||||
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
|
||||
|
||||
if !foundMatchingRoute {
|
||||
logRouteChange(nexthop.IP, intf)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
if ip.Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Destination == unspec {
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultRoutes, foundMatchingRoute
|
||||
}
|
||||
|
||||
func formatRouteInfo(r systemops.Route) string {
|
||||
newIntf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
|
||||
}
|
||||
|
||||
func logRouteChange(ip netip.Addr, intf *net.Interface) {
|
||||
oldIntf := "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
if neighbor == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State == unreachable || n.State == incomplete {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||
log.Infof(
|
||||
"network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
|
||||
neighbor.IPAddress,
|
||||
neighbor.LinkLayerAddress,
|
||||
neighbor.InterfaceAlias,
|
||||
neighbor.InterfaceIndex,
|
||||
n.InterfaceAlias,
|
||||
n.InterfaceIndex,
|
||||
stateFromInt(n.State),
|
||||
)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
entries, err := systemops.GetNeighbors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
|
||||
for _, entry := range entries {
|
||||
neighbours[entry.IPAddress] = entry
|
||||
}
|
||||
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() ([]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func stateFromInt(state uint8) string {
|
||||
switch state {
|
||||
case unreachable:
|
||||
return "unreachable"
|
||||
case incomplete:
|
||||
return "incomplete"
|
||||
case probe:
|
||||
return "probe"
|
||||
case delay:
|
||||
return "delay"
|
||||
case stale:
|
||||
return "stale"
|
||||
case reachable:
|
||||
return "reachable"
|
||||
case permanent:
|
||||
return "permanent"
|
||||
case tbd:
|
||||
return "tbd"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func compareIntf(a, b *net.Interface) int {
|
||||
if a == nil && b == nil {
|
||||
return 0
|
||||
}
|
||||
if a == nil {
|
||||
return -1
|
||||
}
|
||||
if b == nil {
|
||||
return 1
|
||||
}
|
||||
return a.Index - b.Index
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
const (
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
@@ -16,34 +12,7 @@ const (
|
||||
)
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||
type AtomicConnStatus struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||
acs := &AtomicConnStatus{}
|
||||
acs.Set(StatusDisconnected)
|
||||
return acs
|
||||
}
|
||||
|
||||
// Get returns the current connection status
|
||||
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||
return ConnStatus(acs.status.Load())
|
||||
}
|
||||
|
||||
// Set updates the connection status
|
||||
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||
acs.status.Store(int32(status))
|
||||
}
|
||||
|
||||
// String returns the string representation of the current status
|
||||
func (acs *AtomicConnStatus) String() string {
|
||||
return acs.Get().String()
|
||||
}
|
||||
type ConnStatus int
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
|
||||
@@ -2,33 +2,25 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/pion/stun/v2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var connConf = ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
Timeout: time.Second,
|
||||
LocalWgPort: 51820,
|
||||
ICEConfig: ICEConfig{
|
||||
InterfaceBlackList: nil,
|
||||
},
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
StunTurn: []*stun.URI{},
|
||||
InterfaceBlackList: nil,
|
||||
Timeout: time.Second,
|
||||
LocalWgPort: 51820,
|
||||
}
|
||||
|
||||
func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
@@ -44,11 +36,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -59,11 +51,11 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -71,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
<-conn.handshaker.remoteOffersCh
|
||||
<-conn.remoteOffersCh
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -96,11 +88,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -108,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
<-conn.handshaker.remoteAnswerCh
|
||||
<-conn.remoteAnswerCh
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@@ -132,42 +124,62 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables := []struct {
|
||||
name string
|
||||
statusIce ConnStatus
|
||||
statusRelay ConnStatus
|
||||
want ConnStatus
|
||||
name string
|
||||
status ConnStatus
|
||||
want ConnStatus
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
|
||||
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
|
||||
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
|
||||
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
|
||||
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
|
||||
{"StatusConnected", StatusConnected, StatusConnected},
|
||||
{"StatusDisconnected", StatusDisconnected, StatusDisconnected},
|
||||
{"StatusConnecting", StatusConnecting, StatusConnecting},
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
si := NewAtomicConnStatus()
|
||||
si.Set(table.statusIce)
|
||||
conn.statusICE = si
|
||||
|
||||
sr := NewAtomicConnStatus()
|
||||
sr.Set(table.statusRelay)
|
||||
conn.statusRelay = sr
|
||||
conn.status = table.status
|
||||
|
||||
got := conn.Status()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
<-conn.closeCh
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSignalIsNotReady = errors.New("signal is not ready")
|
||||
)
|
||||
|
||||
// IceCredentials ICE protocol credentials struct
|
||||
type IceCredentials struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
}
|
||||
|
||||
// OfferAnswer represents a session establishment offer or answer
|
||||
type OfferAnswer struct {
|
||||
IceCredentials IceCredentials
|
||||
// WgListenPort is a remote WireGuard listen port.
|
||||
// This field is used when establishing a direct WireGuard connection without any proxy.
|
||||
// We can set the remote peer's endpoint with this port.
|
||||
WgListenPort int
|
||||
|
||||
// Version of NetBird Agent
|
||||
Version string
|
||||
// RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
|
||||
// This value is the local Rosenpass server public key when sending the message
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
|
||||
// This value is the local Rosenpass server address when sending the message
|
||||
RosenpassAddr string
|
||||
|
||||
// relay server address
|
||||
RelaySrvAddress string
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
onNewOfferListeners []func(*OfferAnswer)
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
remoteAnswerCh chan OfferAnswer
|
||||
}
|
||||
|
||||
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
return &Handshaker{
|
||||
ctx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
ice: ice,
|
||||
relay: relay,
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen() {
|
||||
for {
|
||||
h.log.Debugf("wait for remote offer confirmation")
|
||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
||||
if err != nil {
|
||||
var connectionClosedError *ConnectionClosedError
|
||||
if errors.As(err, &connectionClosedError) {
|
||||
h.log.Tracef("stop handshaker")
|
||||
return
|
||||
}
|
||||
h.log.Errorf("failed to received remote offer confirmation: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
go listener(remoteOfferAnswer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) SendOffer() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.sendOffer()
|
||||
}
|
||||
|
||||
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
|
||||
select {
|
||||
case h.remoteOffersCh <- offer:
|
||||
return true
|
||||
default:
|
||||
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
|
||||
// connection might not be ready yet to receive so we ignore the message
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
select {
|
||||
case h.remoteAnswerCh <- answer:
|
||||
return true
|
||||
default:
|
||||
// connection might not be ready yet to receive so we ignore the message
|
||||
h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
err := h.sendAnswer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &remoteOfferAnswer, nil
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
return &remoteOfferAnswer, nil
|
||||
case <-h.ctx.Done():
|
||||
// closed externally
|
||||
return nil, NewConnectionClosedError(h.config.Key)
|
||||
}
|
||||
}
|
||||
|
||||
// sendOffer prepares local user credentials and signals them to the remote peer
|
||||
func (h *Handshaker) sendOffer() error {
|
||||
if !h.signaler.Ready() {
|
||||
return ErrSignalIsNotReady
|
||||
}
|
||||
|
||||
iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
|
||||
offer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{iceUFrag, icePwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
}
|
||||
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
if err == nil {
|
||||
offer.RelaySrvAddress = addr
|
||||
}
|
||||
|
||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||
}
|
||||
|
||||
func (h *Handshaker) sendAnswer() error {
|
||||
h.log.Debugf("sending answer")
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
}
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
if err == nil {
|
||||
answer.RelaySrvAddress = addr
|
||||
}
|
||||
|
||||
err = h.signaler.SignalAnswer(answer, h.config.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user