mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 15:43:47 -04:00
Compare commits
150 Commits
feature/ap
...
userspace-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19178b59ec | ||
|
|
cee4aeea9e | ||
|
|
ca9aca9b19 | ||
|
|
e00a280329 | ||
|
|
fe370e7d8f | ||
|
|
125b5e2b16 | ||
|
|
97d498c59c | ||
|
|
0125cd97d8 | ||
|
|
7d385b8dc3 | ||
|
|
f930ef2ee6 | ||
|
|
48f58d776c | ||
|
|
4d635e3c2f | ||
|
|
771c99a523 | ||
|
|
e20be2397c | ||
|
|
46766e7e24 | ||
|
|
a7ddb8f1f8 | ||
|
|
7335c82553 | ||
|
|
a32ec97911 | ||
|
|
5c05131a94 | ||
|
|
b6abd4b4da | ||
|
|
2605948e01 | ||
|
|
a0ca3edb9f | ||
|
|
0837864cfc | ||
|
|
e3d4f9819f | ||
|
|
da43d33540 | ||
|
|
b951fb4aec | ||
|
|
862d548d4d | ||
|
|
eb2ac039c7 | ||
|
|
790a9ed7df | ||
|
|
2e61ce006d | ||
|
|
3cc485759e | ||
|
|
aafa9c67fc | ||
|
|
69f48db0a3 | ||
|
|
8c965434ae | ||
|
|
78da6b42ad | ||
|
|
1ad2cb5582 | ||
|
|
c619bf5b0c | ||
|
|
9f4db0a953 | ||
|
|
3e836db1d1 | ||
|
|
c01874e9ce | ||
|
|
1b2517ea20 | ||
|
|
3e9f0d57ac | ||
|
|
9b5c0439e9 | ||
|
|
481bbe8513 | ||
|
|
bc7b2c6ba3 | ||
|
|
c6f7a299a9 | ||
|
|
992a6c79b4 | ||
|
|
21a3679590 | ||
|
|
77afcc8454 | ||
|
|
22991b3963 | ||
|
|
78795a4a73 | ||
|
|
ea6c947f5d | ||
|
|
5a82477d48 | ||
|
|
1ffa519387 | ||
|
|
e4a25b6a60 | ||
|
|
6a6b527f24 | ||
|
|
b34887a920 | ||
|
|
b9efda3ce8 | ||
|
|
516de93627 | ||
|
|
8dce13113d | ||
|
|
15f0a665f8 | ||
|
|
a625f90ea8 | ||
|
|
9b5b632ff9 | ||
|
|
0c28099712 | ||
|
|
522dd44bfa | ||
|
|
8154069e77 | ||
|
|
e161a92898 | ||
|
|
3fce8485bb | ||
|
|
1cc88a2190 | ||
|
|
168ea9560e | ||
|
|
1c00870ca6 | ||
|
|
1296ecf96e | ||
|
|
8430c37dd6 | ||
|
|
648b22aca1 | ||
|
|
d31543cb12 | ||
|
|
af46f259ac | ||
|
|
f48e33b395 | ||
|
|
f1ed8599fc | ||
|
|
93f3e1b14b | ||
|
|
01957a305d | ||
|
|
649bfb236b | ||
|
|
706f98c1f1 | ||
|
|
6335ef8b48 | ||
|
|
daf935942c | ||
|
|
409003b4f9 | ||
|
|
9e6e34b42d | ||
|
|
28f5cd523a | ||
|
|
d9905d1a57 | ||
|
|
2060242092 | ||
|
|
5ea39dfe8a | ||
|
|
4a189a87ce | ||
|
|
2bd68efc08 | ||
|
|
6848e1e128 | ||
|
|
668aead4c8 | ||
|
|
f08605a7f1 | ||
|
|
02a3feddb8 | ||
|
|
fe7a2aa263 | ||
|
|
290e6992a8 | ||
|
|
474fb33305 | ||
|
|
766e0cccc9 | ||
|
|
7dfe7e426e | ||
|
|
eaadb75144 | ||
|
|
0b116b3941 | ||
|
|
f69dd6fb62 | ||
|
|
62a20f5f1a | ||
|
|
a6ad4dcf22 | ||
|
|
f26b418e83 | ||
|
|
3ce39905c6 | ||
|
|
d9487a5749 | ||
|
|
979fe6bb6a | ||
|
|
cfa6d09c5e | ||
|
|
a01253c3c8 | ||
|
|
c68be6b61b | ||
|
|
fc799effda | ||
|
|
955b2b98e1 | ||
|
|
9490e9095b | ||
|
|
d711172f67 | ||
|
|
0c2fa38e26 | ||
|
|
88b420da6d | ||
|
|
2930288f2d | ||
|
|
0b9854b2b1 | ||
|
|
f772a21f37 | ||
|
|
e912f2d7c0 | ||
|
|
568d064089 | ||
|
|
911f86ded8 | ||
|
|
bc013e4888 | ||
|
|
2b8092dfad | ||
|
|
c3c6afa37b | ||
|
|
fa27369b59 | ||
|
|
657413b8a6 | ||
|
|
d85e57e819 | ||
|
|
7667886794 | ||
|
|
a12a9ac290 | ||
|
|
782e3f8853 | ||
|
|
ed22d79f04 | ||
|
|
03fd656344 | ||
|
|
18b049cd24 | ||
|
|
2bdb4cb44a | ||
|
|
509b4e2132 | ||
|
|
fb1a10755a | ||
|
|
abbdf20f65 | ||
|
|
43ef64cf67 | ||
|
|
9feaa8d767 | ||
|
|
6a97d44d5d | ||
|
|
d2616544fe | ||
|
|
fad82ee65c | ||
|
|
b43a8c56df | ||
|
|
18316be09a | ||
|
|
1a623943c8 | ||
|
|
4199da4a45 |
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.21-bullseye
|
||||
FROM golang:1.23-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.21"
|
||||
"version": "1.23"
|
||||
}
|
||||
},
|
||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||
|
||||
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@@ -44,4 +44,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y go
|
||||
pkg install -y go pkgconf xorg
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
|
||||
235
.github/workflows/golang-test-linux.yml
vendored
235
.github/workflows/golang-test-linux.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
build-cache:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -134,9 +134,189 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
test_management:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
|
||||
benchmark:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||
|
||||
api_integration_test:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -183,56 +363,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
|
||||
benchmark:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||
|
||||
test_client_on_docker:
|
||||
needs: [ build-cache ]
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -9,10 +9,10 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
|
||||
23
.github/workflows/test-infrastructure-files.yml
vendored
23
.github/workflows/test-infrastructure-files.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -34,6 +34,19 @@ jobs:
|
||||
--health-timeout 5s
|
||||
ports:
|
||||
- 5432:5432
|
||||
mysql:
|
||||
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
|
||||
env:
|
||||
MYSQL_USER: netbird
|
||||
MYSQL_PASSWORD: mysql
|
||||
MYSQL_ROOT_PASSWORD: mysqlroot
|
||||
MYSQL_DATABASE: netbird
|
||||
options: >-
|
||||
--health-cmd "mysqladmin ping --silent"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
ports:
|
||||
- 3306:3306
|
||||
steps:
|
||||
- name: Set Database Connection String
|
||||
run: |
|
||||
@@ -42,6 +55,11 @@ jobs:
|
||||
else
|
||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
||||
fi
|
||||
if [ "${{ matrix.store }}" == "mysql" ]; then
|
||||
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
|
||||
else
|
||||
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Install jq
|
||||
run: sudo apt-get install -y jq
|
||||
@@ -84,6 +102,7 @@ jobs:
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
@@ -112,6 +131,7 @@ jobs:
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
|
||||
@@ -149,6 +169,7 @@ jobs:
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
|
||||
@@ -179,6 +179,51 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
@@ -377,6 +422,18 @@ docker_manifests:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
|
||||
2
AUTHORS
2
AUTHORS
@@ -1,3 +1,3 @@
|
||||
Mikhail Bragin (https://github.com/braginini)
|
||||
Maycon Santos (https://github.com/mlsmaycon)
|
||||
Wiretrustee UG (haftungsbeschränkt)
|
||||
NetBird GmbH
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
We are incredibly thankful for the contributions we receive from the community.
|
||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||
as BSD-3 while allowing Wiretrustee to build a sustainable business.
|
||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||
|
||||
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables Wiretrustee to safely commercialize our products
|
||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables NetBird to safely commercialize our products
|
||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||
ability to use the project in their own projects or businesses, to republish modified
|
||||
source, or to completely fork the project.
|
||||
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
|
||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||
carefully review all the terms of the actual CLA before agreeing.
|
||||
|
||||
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
|
||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||
in commercial products.
|
||||
</li>
|
||||
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||
license to use that patent including within commercial products. You also agree that you
|
||||
have permission to grant this license.
|
||||
</li>
|
||||
@@ -45,7 +45,7 @@ more.
|
||||
# Why require a CLA?
|
||||
|
||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
|
||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||
products.
|
||||
|
||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
|
||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||
require you to sign again.
|
||||
|
||||
# Legal Terms and Agreement
|
||||
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
|
||||
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||
your own Contributions for any other purpose.
|
||||
|
||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
|
||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||
You reserve all right, title, and interest in and to Your Contributions.
|
||||
|
||||
1. Definitions.
|
||||
|
||||
```
|
||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
|
||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||
entities that control, are controlled by, or are under common control with that entity are considered
|
||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
```
|
||||
```
|
||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
|
||||
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
|
||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
```
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
|
||||
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
|
||||
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
|
||||
with Wiretrustee.
|
||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||
with NetBird.
|
||||
|
||||
|
||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
|
||||
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
|
||||
4
LICENSE
4
LICENSE
@@ -1,6 +1,6 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
|
||||
Copyright (c) 2022 NetBird GmbH & AUTHORS
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
@@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
<br/>
|
||||
<div align="center">
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
|
||||
17
client/Dockerfile-rootless
Normal file
17
client/Dockerfile-rootless
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM alpine:3.21.0
|
||||
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
WORKDIR /var/lib/netbird
|
||||
USER netbird:netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV NB_USE_NETSTACK_MODE=true
|
||||
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||
ENV NB_CONFIG=config.json
|
||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||
ENV NB_DISABLE_DNS=true
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
|
||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -38,6 +38,7 @@ const (
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,6 +74,7 @@ var (
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Debug(err)
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
|
||||
31
client/cmd/system.go
Normal file
31
client/cmd/system.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cmd
|
||||
|
||||
// Flag constants for system configuration
|
||||
const (
|
||||
disableClientRoutesFlag = "disable-client-routes"
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
)
|
||||
|
||||
var (
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add system flags to upCmd
|
||||
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
}
|
||||
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var traceCmd = &cobra.Command{
|
||||
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||
Short: "Trace a packet through the firewall",
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(traceCmd)
|
||||
|
||||
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||
}
|
||||
|
||||
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
direction := strings.ToLower(args[0])
|
||||
if direction != "in" && direction != "out" {
|
||||
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||
}
|
||||
|
||||
protocol := cmd.Flag("protocol").Value.String()
|
||||
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||
}
|
||||
|
||||
sport, err := cmd.Flags().GetUint16("sport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid source port: %v", err)
|
||||
}
|
||||
dport, err := cmd.Flags().GetUint16("dport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination port: %v", err)
|
||||
}
|
||||
|
||||
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||
if protocol != "icmp" {
|
||||
if sport == 0 {
|
||||
sport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
if dport == 0 {
|
||||
dport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
}
|
||||
|
||||
var tcpFlags *proto.TCPFlags
|
||||
if protocol == "tcp" {
|
||||
syn, _ := cmd.Flags().GetBool("syn")
|
||||
ack, _ := cmd.Flags().GetBool("ack")
|
||||
fin, _ := cmd.Flags().GetBool("fin")
|
||||
rst, _ := cmd.Flags().GetBool("rst")
|
||||
psh, _ := cmd.Flags().GetBool("psh")
|
||||
urg, _ := cmd.Flags().GetBool("urg")
|
||||
|
||||
tcpFlags = &proto.TCPFlags{
|
||||
Syn: syn,
|
||||
Ack: ack,
|
||||
Fin: fin,
|
||||
Rst: rst,
|
||||
Psh: psh,
|
||||
Urg: urg,
|
||||
}
|
||||
}
|
||||
|
||||
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||
SourceIp: args[1],
|
||||
DestinationIp: args[2],
|
||||
Protocol: protocol,
|
||||
SourcePort: uint32(sport),
|
||||
DestinationPort: uint32(dport),
|
||||
Direction: direction,
|
||||
TcpFlags: tcpFlags,
|
||||
IcmpType: &icmpType,
|
||||
IcmpCode: &icmpCode,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||
} else {
|
||||
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||
}
|
||||
}
|
||||
|
||||
disposition := map[bool]string{
|
||||
true: "\033[32mALLOWED\033[0m", // Green
|
||||
false: "\033[31mDENIED\033[0m", // Red
|
||||
}[resp.FinalDisposition]
|
||||
|
||||
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||
}
|
||||
@@ -48,6 +48,7 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -147,6 +148,23 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -172,7 +190,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run()
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
@@ -264,6 +282,23 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
loginRequest.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
loginRequest.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
loginRequest.DisableDns = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
loginRequest.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
24
client/configs/configs.go
Normal file
24
client/configs/configs.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var StateDir string
|
||||
|
||||
func init() {
|
||||
StateDir = os.Getenv("NB_STATE_DIR")
|
||||
if StateDir != "" {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
case "darwin", "linux":
|
||||
StateDir = "/var/lib/netbird"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
StateDir = "/var/db/netbird"
|
||||
}
|
||||
}
|
||||
@@ -14,13 +14,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
@@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -10,4 +12,6 @@ type IFaceMapper interface {
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package iptables
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
@@ -19,8 +19,7 @@ const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
@@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
chain := chainNameInputRules
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
@@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
@@ -145,16 +138,22 @@ func (m *aclManager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
@@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
|
||||
if r.mangleSpecs != nil {
|
||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
@@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
@@ -329,8 +307,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
|
||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
@@ -346,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
@@ -390,42 +359,26 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
return append(specs, "-j", actionToStr(action))
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
@@ -435,15 +388,15 @@ func actionToStr(action firewall.Action) string {
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
case sPort != nil && dPort != nil:
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
case sPort != nil:
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
case dPort != nil:
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
|
||||
@@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@@ -197,11 +196,10 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
@@ -215,6 +213,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
@@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
@@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
@@ -590,10 +599,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(p)
|
||||
portList[i] = strconv.Itoa(int(p))
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
@@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
dPort: &firewall.Port{Values: []uint16{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
@@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
@@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
@@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
@@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
|
||||
@@ -5,9 +5,10 @@ type Rule struct {
|
||||
ruleID string
|
||||
ipsetName string
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
chain string
|
||||
specs []string
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -69,7 +69,6 @@ type Manager interface {
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@@ -100,6 +99,8 @@ type Manager interface {
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
SetLogLevel(log.Level)
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -30,7 +30,7 @@ type Port struct {
|
||||
IsRange bool
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
Values []uint16
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
@@ -40,7 +40,11 @@ func (p *Port) String() string {
|
||||
if ports != "" {
|
||||
ports += ","
|
||||
}
|
||||
ports += strconv.Itoa(port)
|
||||
ports += strconv.Itoa(int(port))
|
||||
}
|
||||
if p.IsRange {
|
||||
ports = "range:" + ports
|
||||
}
|
||||
|
||||
return ports
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,8 +22,7 @@ import (
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameOutputRules = "netbird-acl-output-rules"
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
@@ -45,9 +44,9 @@ type AclManager struct {
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@@ -89,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@@ -104,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if r.nftSet == nil {
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||
if !ok {
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||
if err != nil {
|
||||
@@ -156,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.rConn.DelRule(r.nftRule)
|
||||
if err != nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -214,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
expOut := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainOutputRules,
|
||||
Position: 0,
|
||||
Exprs: expOut,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
@@ -260,25 +239,33 @@ func (m *AclManager) Flush() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
|
||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
r.nftSet,
|
||||
r.ruleID,
|
||||
ip,
|
||||
nftRule: r.nftRule,
|
||||
mangleRule: r.mangleRule,
|
||||
nftSet: r.nftSet,
|
||||
ruleID: r.ruleID,
|
||||
ip: ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -310,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
addrOffset += 4 // is ipv4 address length
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
@@ -342,73 +326,100 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 0,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
expressions = append(expressions, applyPort(sPort, true)...)
|
||||
expressions = append(expressions, applyPort(dPort, false)...)
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
mainExpressions := slices.Clone(expressions)
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
|
||||
var chain *nftables.Chain
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
chain = m.chainInputRules
|
||||
} else {
|
||||
chain = m.chainOutputRules
|
||||
}
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
Exprs: mainExpressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
nftRule: nftRule,
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
m.rules[ruleId] = rule
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if m.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
return m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
@@ -419,15 +430,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// chainNameOutputRules
|
||||
chain = m.createChain(chainNameOutputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
m.chainOutputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
@@ -461,7 +463,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
@@ -469,8 +471,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
|
||||
m.addPreroutingRule(preroutingChain)
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
@@ -480,43 +480,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: preroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
@@ -532,8 +495,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -680,6 +642,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
@@ -696,7 +659,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if m.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -713,22 +676,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
split := bytes.Split(rule.UserData, []byte(" "))
|
||||
r, ok := m.rules[string(split[0])]
|
||||
if ok {
|
||||
*r.nftRule = *rule
|
||||
if mangle {
|
||||
*r.mangleRule = *rule
|
||||
} else {
|
||||
*r.nftRule = *rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(
|
||||
ip net.IP,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
@@ -744,12 +704,6 @@ func generatePeerRuleId(
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func encodePort(port firewall.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
|
||||
@@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -312,6 +318,11 @@ func (m *Manager) cleanupNetbirdTables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
|
||||
@@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -116,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
@@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
@@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
@@ -329,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
for i := range got {
|
||||
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||
_, wantIsCounter := want[i].(*expr.Counter)
|
||||
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||
continue
|
||||
}
|
||||
|
||||
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
rule = r.conn.AddRule(rule)
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// TODO: Insert after the established rule
|
||||
rule = r.conn.InsertRule(rule)
|
||||
} else {
|
||||
rule = r.conn.AddRule(rule)
|
||||
}
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
@@ -956,12 +962,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
@@ -980,7 +986,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||
Data: binaryutil.BigEndian.PutUint16(p),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
dPort: &firewall.Port{Values: []uint16{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []uint16{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
nftRule *nftables.Rule
|
||||
mangleRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
|
||||
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
established atomic.Bool
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (b *BaseConnTrack) IsEstablished() bool {
|
||||
return b.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
||||
b.established.Store(state)
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
|
||||
@@ -3,8 +3,14 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
@@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
||||
})
|
||||
|
||||
}
|
||||
func BenchmarkAtomicOperations(b *testing.B) {
|
||||
conn := &BaseConnTrack{}
|
||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.IsEstablished()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SetEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.SetEstablished(i%2 == 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.GetLastSeen()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
@@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -42,12 +45,13 @@ type ICMPTracker struct {
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
@@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
switch icmpType {
|
||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||
return true
|
||||
case uint8(layers.ICMPv4TypeEchoReply):
|
||||
// continue processing
|
||||
default:
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
@@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
||||
@@ -5,7 +5,10 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,12 +64,24 @@ type TCPConnKey struct {
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -76,8 +91,9 @@ type TCPTracker struct {
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
@@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
@@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
@@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
type UDPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -29,12 +32,13 @@ type UDPTracker struct {
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
@@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
@@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
@@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout)
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.True(t, conn.IsEstablished())
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1 * time.Second)
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
||||
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (e *endpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityNone
|
||||
}
|
||||
|
||||
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) Wait() {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultReceiveWindow = 32768
|
||||
defaultMaxInFlight = 1024
|
||||
iosReceiveWindow = 16384
|
||||
iosMaxInFlight = 256
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
maxInFlight := defaultMaxInFlight
|
||||
if runtime.GOOS == "ios" {
|
||||
receiveWindow = iosReceiveWindow
|
||||
maxInFlight = iosMaxInFlight
|
||||
}
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
if f.udpForwarder != nil {
|
||||
f.udpForwarder.Stop()
|
||||
}
|
||||
|
||||
f.stack.Close()
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
}
|
||||
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
288
client/firewall/uspfilter/forwarder/udp.go
Normal file
288
client/firewall/uspfilter/forwarder/udp.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
udpTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
id stack.TransportEndpointID
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
go f.cleanup()
|
||||
return f
|
||||
}
|
||||
|
||||
// Stop stops the UDP forwarder and all active connections
|
||||
func (f *udpForwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
delete(f.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
var idleConns []idleConn
|
||||
|
||||
f.RLock()
|
||||
for id, conn := range f.conns {
|
||||
if conn.getIdleDuration() > udpTimeout {
|
||||
idleConns = append(idleConns, idleConn{id, conn})
|
||||
}
|
||||
}
|
||||
f.RUnlock()
|
||||
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
|
||||
f.Lock()
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
|
||||
f.udpForwarder.RLock()
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
// Double-check no connection was created while we were setting up
|
||||
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
|
||||
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
134
client/firewall/uspfilter/localip.go
Normal file
134
client/firewall/uspfilter/localip.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
return &localIPManager{}
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
ipStr := ip.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
270
client/firewall/uspfilter/localip_test.go
Normal file
270
client/firewall/uspfilter/localip_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
|
||||
err := manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.IsLocalIP(tt.testIP)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
mock := &IFaceMock{}
|
||||
|
||||
// Get actual local interfaces
|
||||
interfaces, err := net.Interfaces()
|
||||
require.NoError(t, err)
|
||||
|
||||
var tests []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}
|
||||
|
||||
// Add all local interface IPs to test cases
|
||||
for _, iface := range interfaces {
|
||||
addrs, err := iface.Addrs()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip4.String(),
|
||||
expected: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add some external IPs as negative test cases
|
||||
externalIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"208.67.222.222",
|
||||
}
|
||||
for _, ip := range externalIPs {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip,
|
||||
expected: false,
|
||||
})
|
||||
}
|
||||
|
||||
require.NotEmpty(t, tests, "No test cases generated")
|
||||
|
||||
err = manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
196
client/firewall/uspfilter/log/log.go
Normal file
196
client/firewall/uspfilter/log/log.go
Normal file
@@ -0,0 +1,196 @@
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
LevelPanic Level = iota
|
||||
LevelFatal
|
||||
LevelError
|
||||
LevelWarn
|
||||
LevelInfo
|
||||
LevelDebug
|
||||
LevelTrace
|
||||
)
|
||||
|
||||
var levelStrings = map[Level]string{
|
||||
LevelPanic: "PANC",
|
||||
LevelFatal: "FATL",
|
||||
LevelError: "ERRO",
|
||||
LevelWarn: "WARN",
|
||||
LevelInfo: "INFO",
|
||||
LevelDebug: "DEBG",
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.worker()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the logger
|
||||
func (l *Logger) Stop(ctx context.Context) error {
|
||||
done := make(chan struct{})
|
||||
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.shutdown)
|
||||
})
|
||||
|
||||
go func() {
|
||||
l.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -2,22 +2,22 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
comment string
|
||||
|
||||
@@ -25,6 +25,21 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
390
client/firewall/uspfilter/tracer.go
Normal file
390
client/firewall/uspfilter/tracer.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
)
|
||||
|
||||
type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
}[s]
|
||||
}
|
||||
|
||||
type ForwarderAction struct {
|
||||
Action string
|
||||
RemoteAddr string
|
||||
Error error
|
||||
}
|
||||
|
||||
type TraceResult struct {
|
||||
Timestamp time.Time
|
||||
Stage PacketStage
|
||||
Message string
|
||||
Allowed bool
|
||||
ForwarderAction *ForwarderAction
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
Direction fw.RuleDirection
|
||||
Results []TraceResult
|
||||
}
|
||||
|
||||
type TCPState struct {
|
||||
SYN bool
|
||||
ACK bool
|
||||
FIN bool
|
||||
RST bool
|
||||
PSH bool
|
||||
URG bool
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Direction fw.RuleDirection
|
||||
PayloadSize int
|
||||
TCPState *TCPState
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
ForwarderAction: action,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers = append(pktLayers, transportLayer...)
|
||||
|
||||
if p.PayloadSize > 0 {
|
||||
payload := make([]byte, p.PayloadSize)
|
||||
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||
}
|
||||
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ip)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ip)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
Window: 65535,
|
||||
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||
RST: p.TCPState != nil && p.TCPState.RST,
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||
icmp.Id = uint16(1)
|
||||
icmp.Seq = uint16(1)
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
|
||||
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return int(layers.IPProtocolTCP)
|
||||
case fw.ProtocolUDP:
|
||||
return int(layers.IPProtocolUDP)
|
||||
case fw.ProtocolICMP:
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||
packetData, err := builder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build packet: %w", err)
|
||||
}
|
||||
|
||||
return m.TracePacket(packetData, builder.Direction), nil
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
|
||||
// Extract base packet info
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
trace.SourceIP = srcIP
|
||||
trace.DestinationIP = dstIP
|
||||
|
||||
// Determine protocol and ports
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
trace.Protocol = "TCP"
|
||||
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
trace.Protocol = "UDP"
|
||||
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
return m.traceOutbound(packetData, trace)
|
||||
}
|
||||
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
trace.AddResult(StageConntrack, msg, true)
|
||||
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||
return true
|
||||
}
|
||||
trace.AddResult(StageConntrack, msg, false)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
msg := "Matched existing connection state"
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||
flags&conntrack.TCPSyn != 0,
|
||||
flags&conntrack.TCPAck != 0,
|
||||
flags&conntrack.TCPRst != 0,
|
||||
flags&conntrack.TCPFin != 0)
|
||||
case layers.LayerTypeICMPv4:
|
||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
|
||||
if m.netstack {
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
}
|
||||
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
msg := "Allowed by route ACLs"
|
||||
if !allowed {
|
||||
msg = "Blocked by route ACLs"
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||
fwdAction := &ForwarderAction{
|
||||
Action: action,
|
||||
RemoteAddr: remoteAddr,
|
||||
}
|
||||
trace.AddResultWithForwarder(StageForwarding,
|
||||
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||
}
|
||||
return trace
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -14,44 +17,81 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
type RouteRules []RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
return -1
|
||||
}
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
})
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
stateful bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
netstack bool
|
||||
// indicates whether we forward local traffic to the native stack
|
||||
localForwarding bool
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
logger *nblog.Logger
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -68,22 +108,32 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.nativeFirewall = nativeFirewall
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||
}
|
||||
enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding))
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -99,52 +149,161 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
// default true for non-netstack, for netstack only if explicitly enabled
|
||||
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
||||
}
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
m.determineRouting(iface, disableServerRoutes)
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
|
||||
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
|
||||
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding ||
|
||||
m.routingEnabled && !m.nativeRouter {
|
||||
|
||||
m.initForwarder(iface)
|
||||
}
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder(iface common.IFaceMapper) {
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := iface.GetWGDevice()
|
||||
if intf == nil {
|
||||
log.Info("forwarding not supported")
|
||||
m.routingEnabled = false
|
||||
return
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(iface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create forwarder: %v", err)
|
||||
m.routingEnabled = false
|
||||
return
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
|
||||
// userspace routed packets are always SNATed to the inbound direction
|
||||
// TODO: implement outbound SNAT
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
@@ -156,17 +315,15 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
@@ -179,13 +336,8 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) == 1 {
|
||||
r.sPort = uint16(sPort.Values[0])
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) == 1 {
|
||||
r.dPort = uint16(dPort.Values[0])
|
||||
}
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
@@ -202,33 +354,64 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
id: ruleID,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.GetRuleID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -236,24 +419,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
r, ok := rule.(*PeerRule)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -276,10 +450,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules)
|
||||
return m.dropFilter(packetData)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -300,18 +478,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always process UDP hooks
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
// Track other protocols only if stateful mode is enabled
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
@@ -319,6 +490,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -380,7 +556,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
@@ -400,8 +576,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -409,25 +586,120 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
m.logger.Trace("Invalid packet structure")
|
||||
return true
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check connection state only if enabled
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.applyRules(srcIP, packetData, rules, d)
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if !m.localForwarding {
|
||||
m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack {
|
||||
m.handleNetstackLocalTraffic(packetData)
|
||||
|
||||
// don't process this packet further
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
|
||||
if m.forwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get protocol and ports for route ACL check
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
// Check route ACLs
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0, 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
@@ -443,10 +715,6 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
@@ -481,7 +749,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
}
|
||||
@@ -498,7 +781,24 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
if rulePort == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if rulePort.IsRange {
|
||||
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
|
||||
}
|
||||
|
||||
for _, p := range rulePort.Values {
|
||||
if p == packetPort {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
@@ -515,13 +815,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
@@ -531,13 +825,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
@@ -547,6 +835,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
sourceMatched := false
|
||||
for _, src := range rule.sources {
|
||||
if src.Contains(srcAddr) {
|
||||
sourceMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
@@ -558,13 +891,12 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@@ -575,14 +907,13 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
@@ -594,21 +925,31 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(level log.Level) {
|
||||
if m.logger != nil {
|
||||
m.logger.SetLevel(nblog.Level(level))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -91,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
|
||||
fw.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -112,9 +115,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{1024 + i}},
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -126,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
|
||||
fw.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, manager.incomingRules)
|
||||
manager.dropFilter(testIn)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
@@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -588,9 +591,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
@@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -679,9 +682,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -797,9 +800,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
@@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -884,9 +887,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
// Add several route rules to simulate real-world scenario
|
||||
rules := []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
port *fw.Port
|
||||
}{
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
port: &fw.Port{Values: []uint16{80, 443}},
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolICMP,
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: fw.ProtocolUDP,
|
||||
port: &fw.Port{Values: []uint16{53}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases that exercise different matching scenarios
|
||||
cases := []struct {
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
dstPort uint16
|
||||
}{
|
||||
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1014
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1014
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,17 +9,38 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||
if i.GetDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
@@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -69,12 +90,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -96,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -104,38 +124,16 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP("192.168.1.1")
|
||||
proto = fw.ProtocolTCP
|
||||
port = &fw.Port{Values: []int{80}}
|
||||
direction = fw.RuleDirectionIN
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
@@ -189,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
@@ -217,18 +215,14 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if tt.expDir != addedRule.direction {
|
||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
@@ -242,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -250,12 +244,11 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -275,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -289,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -327,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -346,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface)
|
||||
manager, err := Create(iface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -392,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -400,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
@@ -478,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -492,12 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -509,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -518,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -639,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -710,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "utun100"
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -15,4 +17,5 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
@@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
|
||||
@@ -64,7 +64,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
log.Debug("Attaching to interface")
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -33,8 +34,6 @@ type TunKernelDevice struct {
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
checkUser()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &TunKernelDevice{
|
||||
ctx: ctx,
|
||||
@@ -153,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Device returns the wireguard device, not applicable for kernel devices
|
||||
func (t *TunKernelDevice) Device() *device.Device {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
|
||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunNetstackDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@@ -32,8 +30,6 @@ type USPDevice struct {
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
checkUser()
|
||||
|
||||
return &USPDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -128,18 +124,14 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
return link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func checkUser() {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
euid := os.Geteuid()
|
||||
if euid != 0 {
|
||||
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
if t.nativeTunDevice == nil {
|
||||
return "", fmt.Errorf("interface has not been initialized yet")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -13,4 +15,5 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
}
|
||||
|
||||
@@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
|
||||
return fmt.Errorf("set interface addr: %w", err)
|
||||
}
|
||||
|
||||
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
||||
return w.tun.FilteredDevice()
|
||||
}
|
||||
|
||||
// GetWGDevice returns the WireGuard device
|
||||
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@@ -29,6 +30,7 @@ type MockWGIface struct {
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
@@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@@ -32,5 +33,6 @@ type IWGIface interface {
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@@ -30,6 +31,7 @@ type IWGIface interface {
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,10 @@ func IsEnabled() bool {
|
||||
|
||||
func ListenAddr() string {
|
||||
sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
|
||||
if sPort == "" {
|
||||
return listenAddr(DefaultSocks5Port)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(sPort)
|
||||
if err != nil {
|
||||
log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)
|
||||
|
||||
@@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
@@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []int{value},
|
||||
Values: []uint16{uint16(value)},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
@@ -300,6 +305,22 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
if portInfo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch portInfo.GetPortSelection().(type) {
|
||||
case *mgmProto.PortInfo_Port:
|
||||
return portInfo.GetPort() == 0
|
||||
case *mgmProto.PortInfo_Range_:
|
||||
r := portInfo.GetRange()
|
||||
return r == nil || r.Start == 0 || r.End == 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
@@ -308,25 +329,12 @@ func (d *DefaultManager) addInRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@@ -337,25 +345,16 @@ func (d *DefaultManager) addOutRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
@@ -508,7 +507,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
@@ -559,14 +558,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewall.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
Values: []uint16{uint16(int(portInfo.GetPort()))},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewall.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
@@ -119,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
@@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
@@ -356,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||
}
|
||||
|
||||
// GetDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDevice")
|
||||
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDevice indicates an expected call of GetDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||
}
|
||||
|
||||
// GetWGDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||
ret0, _ := ret[0].(*wgdevice.Device)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,7 +13,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
@@ -56,6 +61,18 @@ func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*Devi
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
} else {
|
||||
log.Debug("Using system certificate pool.")
|
||||
}
|
||||
|
||||
httpTransport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
|
||||
@@ -61,6 +61,13 @@ type ConfigInput struct {
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
|
||||
DisableClientRoutes *bool
|
||||
DisableServerRoutes *bool
|
||||
DisableDNS *bool
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -78,6 +85,14 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
@@ -402,7 +417,56 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
|
||||
if *input.DisableClientRoutes {
|
||||
log.Infof("disabling client routes")
|
||||
} else {
|
||||
log.Infof("enabling client routes")
|
||||
}
|
||||
config.DisableClientRoutes = *input.DisableClientRoutes
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
|
||||
if *input.DisableServerRoutes {
|
||||
log.Infof("disabling server routes")
|
||||
} else {
|
||||
log.Infof("enabling server routes")
|
||||
}
|
||||
config.DisableServerRoutes = *input.DisableServerRoutes
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
|
||||
if *input.DisableDNS {
|
||||
log.Infof("disabling DNS configuration")
|
||||
} else {
|
||||
log.Infof("enabling DNS configuration")
|
||||
}
|
||||
config.DisableDNS = *input.DisableDNS
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
|
||||
if *input.DisableFirewall {
|
||||
log.Infof("disabling firewall configuration")
|
||||
} else {
|
||||
log.Infof("enabling firewall configuration")
|
||||
}
|
||||
config.DisableFirewall = *input.DisableFirewall
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
|
||||
if *input.BlockLANAccess {
|
||||
log.Infof("blocking LAN access")
|
||||
} else {
|
||||
log.Infof("allowing LAN access")
|
||||
}
|
||||
config.BlockLANAccess = *input.BlockLANAccess
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
|
||||
@@ -59,13 +59,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run() error {
|
||||
return c.run(MobileDependency{}, nil, nil)
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -84,7 +79,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -102,10 +97,10 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
@@ -182,8 +177,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
}()
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
@@ -204,8 +199,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
signalURL := fmt.Sprintf("%s://%s",
|
||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||
strings.ToLower(loginResp.GetNetbirdConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetNetbirdConfig().GetSignal().GetUri(),
|
||||
)
|
||||
|
||||
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||
@@ -216,8 +211,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
// with the global Netbird config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetNetbirdConfig(), myPrivateKey)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -261,7 +256,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -316,7 +311,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
|
||||
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
|
||||
relayCfg := loginResp.GetNetbirdConfig().GetRelay()
|
||||
if relayCfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -382,8 +377,7 @@ func (c *ConnectClient) isContextCancelled() bool {
|
||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||
// When enabled, the last received network map will be stored and can be retrieved
|
||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||
// network map will be cleared. This functionality is primarily used for debugging
|
||||
// and should not be enabled during normal operation.
|
||||
// network map will be cleared.
|
||||
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
||||
c.engineMutex.Lock()
|
||||
c.persistNetworkMap = enabled
|
||||
@@ -416,6 +410,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes,
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -439,7 +440,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
}
|
||||
|
||||
// connectToSignal creates Signal Service client and established a connection
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
var sigTLSEnabled bool
|
||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||
sigTLSEnabled = true
|
||||
@@ -456,8 +457,8 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
||||
return signalClient, nil
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
@@ -465,6 +466,15 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
18
client/internal/dns/consts.go
Normal file
18
client/internal/dns/consts.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var fileUncleanShutdownResolvConfLocation string
|
||||
|
||||
func init() {
|
||||
fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE")
|
||||
if fileUncleanShutdownResolvConfLocation == "" {
|
||||
fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
)
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
)
|
||||
@@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
origPattern := pattern
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
pattern = dns.Fqdn(pattern)
|
||||
origPattern = dns.Fqdn(origPattern)
|
||||
|
||||
// First remove any existing handler with same original pattern and priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
@@ -106,17 +105,30 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
pos := 0
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||
for i, h := range c.handlers {
|
||||
if h.Priority < priority {
|
||||
pos = i
|
||||
break
|
||||
// prio first
|
||||
if h.Priority < newEntry.Priority {
|
||||
return i
|
||||
}
|
||||
|
||||
// domain specificity next
|
||||
if h.Priority == newEntry.Priority {
|
||||
newDots := strings.Count(newEntry.Pattern, ".")
|
||||
existingDots := strings.Count(h.Pattern, ".")
|
||||
if newDots > existingDots {
|
||||
return i
|
||||
}
|
||||
}
|
||||
pos = i + 1
|
||||
}
|
||||
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
// add at end
|
||||
return len(c.handlers)
|
||||
}
|
||||
|
||||
// RemoveHandler removes a handler for the given pattern and priority
|
||||
@@ -126,10 +138,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
// Find and remove handlers matching both original pattern and priority
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
@@ -144,9 +156,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if entry.Pattern == pattern {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -158,7 +170,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
log.Tracef("handling DNS request for domain=%s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
@@ -187,9 +199,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
} else {
|
||||
matched = qname == entry.Pattern
|
||||
matched = strings.EqualFold(qname, entry.Pattern)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -507,5 +507,326 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
addHandlers []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}
|
||||
query string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "case insensitive exact match",
|
||||
scenario: "handler registered lowercase, query uppercase",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "case insensitive wildcard match",
|
||||
scenario: "handler registered mixed case wildcard, query different case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "sub.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different case same domain",
|
||||
scenario: "second handler should replace first despite case difference",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "ExAmPlE.cOm.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "subdomain matching case insensitive",
|
||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, true, true},
|
||||
},
|
||||
query: "SUB.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone case insensitive",
|
||||
scenario: "root zone handler should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{".", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different priority",
|
||||
scenario: "should call higher priority handler despite case differences",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||
|
||||
// Add handlers according to test case
|
||||
for _, h := range tt.addHandlers {
|
||||
var handler dns.Handler
|
||||
pattern := h.pattern // capture pattern for closure
|
||||
|
||||
if h.subdomains {
|
||||
subHandler := &nbdns.MockSubdomainHandler{
|
||||
Subdomains: true,
|
||||
}
|
||||
if h.shouldMatch {
|
||||
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = subHandler
|
||||
} else {
|
||||
mockHandler := &nbdns.MockHandler{}
|
||||
if h.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||
|
||||
// Verify each handler was called exactly as expected
|
||||
for _, h := range tt.addHandlers {
|
||||
wasCalled := handlerCalls[h.pattern]
|
||||
assert.Equal(t, h.shouldMatch, wasCalled,
|
||||
"Handler for pattern %q was %s when it should%s have been",
|
||||
h.pattern,
|
||||
map[bool]string{true: "called", false: "not called"}[wasCalled],
|
||||
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
|
||||
}
|
||||
|
||||
// Verify total number of calls
|
||||
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
|
||||
"Wrong number of total handler calls")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
ops []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}
|
||||
query string
|
||||
expectedMatch string
|
||||
}{
|
||||
{
|
||||
name: "more specific domain matches first",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "more specific domain matches first, both match subdomains",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "maintain specificity order after removal",
|
||||
scenario: "after removing most specific, should fall back to less specific",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "test.sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "priority overrides specificity",
|
||||
scenario: "less specific domain with higher priority should match first",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "example.com.",
|
||||
},
|
||||
{
|
||||
name: "equal priority respects specificity",
|
||||
scenario: "with equal priority, more specific domain should match",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "specific matches before wildcard",
|
||||
scenario: "specific domain should match before wildcard at same priority",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||
handlers[op.pattern] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
}
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetReply(r)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
|
||||
} else {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
type noopHostConfigurator struct{}
|
||||
|
||||
func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ const (
|
||||
arraySymbol = "* "
|
||||
digitSymbol = "# "
|
||||
scutilPath = "/usr/sbin/scutil"
|
||||
dscacheutilPath = "/usr/bin/dscacheutil"
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
@@ -106,6 +107,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -123,6 +128,10 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -316,6 +325,21 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) flushDNSCache() error {
|
||||
cmd := exec.Command(dscacheutilPath, "-flushcache")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("flush DNS cache: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
|
||||
@@ -48,11 +48,17 @@ type restoreHostManager interface {
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -15,23 +16,64 @@ import (
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
|
||||
// resolvconfType represents the type of resolvconf implementation
|
||||
type resolvconfType int
|
||||
|
||||
func (r resolvconfType) String() string {
|
||||
switch r {
|
||||
case typeOpenresolv:
|
||||
return "openresolv"
|
||||
case typeResolvconf:
|
||||
return "resolvconf"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
typeOpenresolv resolvconfType = iota
|
||||
typeResolvconf
|
||||
)
|
||||
|
||||
type resolvconf struct {
|
||||
ifaceName string
|
||||
implType resolvconfType
|
||||
|
||||
originalSearchDomains []string
|
||||
originalNameServers []string
|
||||
othersConfigs []string
|
||||
}
|
||||
|
||||
// supported "openresolv" only
|
||||
func detectResolvconfType() (resolvconfType, error) {
|
||||
cmd := exec.Command(resolvconfCommand, "--version")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(string(out), "openresolv") {
|
||||
return typeOpenresolv, nil
|
||||
}
|
||||
return typeResolvconf, nil
|
||||
}
|
||||
|
||||
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
||||
resolvConfEntries, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
implType, err := detectResolvconfType()
|
||||
if err != nil {
|
||||
log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err)
|
||||
implType = typeOpenresolv
|
||||
} else {
|
||||
log.Infof("detected resolvconf type: %v", implType)
|
||||
}
|
||||
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface,
|
||||
implType: implType,
|
||||
originalSearchDomains: resolvConfEntries.searchDomains,
|
||||
originalNameServers: resolvConfEntries.nameServers,
|
||||
othersConfigs: resolvConfEntries.others,
|
||||
@@ -80,8 +122,15 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
||||
}
|
||||
|
||||
func (r *resolvconf) restoreHostDNS() error {
|
||||
// openresolv only, debian resolvconf doesn't support "-f"
|
||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case typeOpenresolv:
|
||||
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
case typeResolvconf:
|
||||
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||
}
|
||||
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||
@@ -91,10 +140,21 @@ func (r *resolvconf) restoreHostDNS() error {
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
// openresolv only, debian resolvconf doesn't support "-x"
|
||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch r.implType {
|
||||
case typeOpenresolv:
|
||||
// OpenResolv supports exclusive mode with -x
|
||||
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
case typeResolvconf:
|
||||
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
|
||||
default:
|
||||
return fmt.Errorf("unsupported resolvconf type: %v", r.implType)
|
||||
}
|
||||
|
||||
cmd.Stdin = &content
|
||||
_, err := cmd.Output()
|
||||
out, err := cmd.Output()
|
||||
log.Tracef("resolvconf output: %s", out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -47,6 +48,7 @@ type registeredHandlerMap map[string]handlerWithStop
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
@@ -84,7 +86,14 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -101,7 +110,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
@@ -112,9 +121,10 @@ func NewDefaultServerPermanentUpstream(
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@@ -131,17 +141,26 @@ func NewDefaultServerIos(
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||
func newDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
dnsService service,
|
||||
statusRecorder *peer.Status,
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
@@ -220,6 +239,16 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
}
|
||||
|
||||
s.stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// use noop host manager if requested or running in netstack mode.
|
||||
// Netstack mode currently doesn't have a way to receive DNS requests.
|
||||
// TODO: Use listener on localhost in netstack mode when running as root.
|
||||
if s.disableSys || netstack.IsEnabled() {
|
||||
log.Info("system DNS is disabled, not setting up host manager")
|
||||
s.hostManager = &noopHostConfigurator{}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.hostManager, err = s.initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
@@ -268,47 +297,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
if s.ctx.Err() != nil {
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
}
|
||||
|
||||
if s.hostManager == nil {
|
||||
return fmt.Errorf("dns service is not initialized yet")
|
||||
}
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
if s.hostManager == nil {
|
||||
return fmt.Errorf("dns service is not initialized yet")
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
@@ -627,8 +656,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.currentConfig.RouteAll = true
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
|
||||
@@ -294,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -403,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -498,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@@ -633,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -657,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -749,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface)
|
||||
pf, err := uspfilter.Create(wgIface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -81,9 +81,14 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
Values: []uint16{ListenPort},
|
||||
}
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
|
||||
if h.firewall == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user